-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[DAGCombiner] Honor rewrite semantics of fast-math flags in fdiv combine #167595
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[DAGCombiner] Honor rewrite semantics of fast-math flags in fdiv combine #167595
Conversation
|
@llvm/pr-subscribers-backend-aarch64 @llvm/pr-subscribers-backend-x86 Author: Mikołaj Piróg (mikolaj-pirog) ChangesAs in title. Rewrite semantics, as defined here: https://llvm.org/docs/LangRef.html#floating-point-semantics, dictate that when a given transformation happens, all of instructions taking part in this transformation need to have appropriate flag present. In the case of this change I understand these semantics as following:
Patch is 22.01 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/167595.diff 13 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index f144f17d5a8f2..8b5f633b99dd1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -18613,6 +18613,8 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
EVT VT = N->getValueType(0);
SDLoc DL(N);
SDNodeFlags Flags = N->getFlags();
+ SDNodeFlags FlagsN1 = N1->getFlags();
+
SelectionDAG::FlagInserter FlagsInserter(DAG, N);
if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
@@ -18657,18 +18659,25 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
if (Flags.hasAllowReciprocal()) {
// If this FDIV is part of a reciprocal square root, it may be folded
// into a target-specific square root estimate instruction.
- if (N1.getOpcode() == ISD::FSQRT) {
+ // X / sqrt(Y) -> X * rsqrt(Y)
+ bool N1AllowReciprocal = FlagsN1.hasAllowReciprocal();
+ bool N1Op0AllowsReciprocal =
+ N1.getNumOperands() > 0 &&
+ N1.getOperand(0)->getFlags().hasAllowReciprocal();
+ if (N1.getOpcode() == ISD::FSQRT && N1AllowReciprocal) {
if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0)))
return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
} else if (N1.getOpcode() == ISD::FP_EXTEND &&
- N1.getOperand(0).getOpcode() == ISD::FSQRT) {
+ N1.getOperand(0).getOpcode() == ISD::FSQRT &&
+ N1Op0AllowsReciprocal && N1AllowReciprocal) {
if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0).getOperand(0))) {
RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
AddToWorklist(RV.getNode());
return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
}
} else if (N1.getOpcode() == ISD::FP_ROUND &&
- N1.getOperand(0).getOpcode() == ISD::FSQRT) {
+ N1.getOperand(0).getOpcode() == ISD::FSQRT &&
+ N1Op0AllowsReciprocal && N1AllowReciprocal) {
if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0).getOperand(0))) {
RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
AddToWorklist(RV.getNode());
@@ -18688,8 +18697,10 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
if (Sqrt.getNode()) {
// If the other multiply operand is known positive, pull it into the
// sqrt. That will eliminate the division if we convert to an estimate.
- if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
- N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
+ if (N1.hasOneUse() && Sqrt.hasOneUse() &&
+ Sqrt->getFlags().hasAllowReciprocal() &&
+ Sqrt->getFlags().hasAllowReassociation() &&
+ FlagsN1.hasAllowReciprocal() && FlagsN1.hasAllowReassociation()) {
SDValue A;
if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
A = Y.getOperand(0);
@@ -18711,7 +18722,10 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
// We found a FSQRT, so try to make this fold:
// X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
- if (SDValue Rsqrt = buildRsqrtEstimate(Sqrt.getOperand(0))) {
+ SDValue Rsqrt;
+ if (N1AllowReciprocal && Sqrt->getFlags().hasAllowReciprocal() &&
+ (Rsqrt = buildRsqrtEstimate(Sqrt.getOperand(0)))) {
+ Rsqrt = buildRsqrtEstimate(Sqrt.getOperand(0));
SDValue Div = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, Rsqrt, Y);
AddToWorklist(Div.getNode());
return DAG.getNode(ISD::FMUL, DL, VT, N0, Div);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 88b0809b767b5..ecc26029ee152 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -3976,7 +3976,10 @@ void SelectionDAGBuilder::visitFPExt(const User &I) {
SDValue N = getValue(I.getOperand(0));
EVT DestVT = DAG.getTargetLoweringInfo().getValueType(DAG.getDataLayout(),
I.getType());
- setValue(&I, DAG.getNode(ISD::FP_EXTEND, getCurSDLoc(), DestVT, N));
+ SDNodeFlags Flags;
+ if (auto *TruncInst = dyn_cast<FPMathOperator>(&I))
+ Flags.copyFMF(*TruncInst);
+ setValue(&I, DAG.getNode(ISD::FP_EXTEND, getCurSDLoc(), DestVT, N, Flags));
}
void SelectionDAGBuilder::visitFPToUI(const User &I) {
diff --git a/llvm/test/CodeGen/AArch64/sqrt-fastmath.ll b/llvm/test/CodeGen/AArch64/sqrt-fastmath.ll
index e29993d02935a..737b3d903ed0f 100644
--- a/llvm/test/CodeGen/AArch64/sqrt-fastmath.ll
+++ b/llvm/test/CodeGen/AArch64/sqrt-fastmath.ll
@@ -490,7 +490,7 @@ define <2 x double> @sqrt_fdiv_common_operand_vec(<2 x double> %x) nounwind {
; CHECK-NEXT: fmul v0.2d, v0.2d, v1.2d
; CHECK-NEXT: fmul v0.2d, v0.2d, v2.2d
; CHECK-NEXT: ret
- %sqrt = call <2 x double> @llvm.sqrt.v2f64(<2 x double> %x)
+ %sqrt = call arcp <2 x double> @llvm.sqrt.v2f64(<2 x double> %x)
%r = fdiv arcp nsz reassoc <2 x double> %x, %sqrt
ret <2 x double> %r
}
diff --git a/llvm/test/CodeGen/AMDGPU/fdiv_flags.f32.ll b/llvm/test/CodeGen/AMDGPU/fdiv_flags.f32.ll
index 38239c5509318..cf21a0ca1c47b 100644
--- a/llvm/test/CodeGen/AMDGPU/fdiv_flags.f32.ll
+++ b/llvm/test/CodeGen/AMDGPU/fdiv_flags.f32.ll
@@ -981,7 +981,7 @@ define float @v_fdiv_recip_sqrt_f32_arcp_fdiv_only(float %x) {
; IR-DAZ-GISEL-NEXT: v_div_fmas_f32 v1, v1, v2, v4
; IR-DAZ-GISEL-NEXT: v_div_fixup_f32 v0, v1, v0, 1.0
; IR-DAZ-GISEL-NEXT: s_setpc_b64 s[30:31]
- %sqrt = call float @llvm.sqrt.f32(float %x)
+ %sqrt = call arcp float @llvm.sqrt.f32(float %x)
%fdiv = fdiv arcp float 1.0, %sqrt
ret float %fdiv
}
@@ -1297,7 +1297,7 @@ define float @v_fdiv_recip_sqrt_f32_arcp_afn_fdiv_only(float %x) {
; IR-DAZ-GISEL-NEXT: v_cndmask_b32_e32 v0, v1, v0, vcc
; IR-DAZ-GISEL-NEXT: v_rcp_f32_e32 v0, v0
; IR-DAZ-GISEL-NEXT: s_setpc_b64 s[30:31]
- %sqrt = call float @llvm.sqrt.f32(float %x)
+ %sqrt = call arcp float @llvm.sqrt.f32(float %x)
%fdiv = fdiv arcp afn float 1.0, %sqrt
ret float %fdiv
}
diff --git a/llvm/test/CodeGen/AMDGPU/fsqrt.r600.ll b/llvm/test/CodeGen/AMDGPU/fsqrt.r600.ll
index c93c077706046..d1cb9632bc9fe 100644
--- a/llvm/test/CodeGen/AMDGPU/fsqrt.r600.ll
+++ b/llvm/test/CodeGen/AMDGPU/fsqrt.r600.ll
@@ -228,7 +228,7 @@ define amdgpu_kernel void @recip_sqrt(ptr addrspace(1) %out, float %src) nounwin
; R600-NEXT: LSHR T0.X, KC0[2].Y, literal.x,
; R600-NEXT: RECIPSQRT_IEEE * T1.X, KC0[2].Z,
; R600-NEXT: 2(2.802597e-45), 0(0.000000e+00)
- %sqrt = call float @llvm.sqrt.f32(float %src)
+ %sqrt = call arcp float @llvm.sqrt.f32(float %src)
%recipsqrt = fdiv fast float 1.0, %sqrt
store float %recipsqrt, ptr addrspace(1) %out, align 4
ret void
diff --git a/llvm/test/CodeGen/AMDGPU/rsq.f32.ll b/llvm/test/CodeGen/AMDGPU/rsq.f32.ll
index f967e951b27a4..03e258fc84d61 100644
--- a/llvm/test/CodeGen/AMDGPU/rsq.f32.ll
+++ b/llvm/test/CodeGen/AMDGPU/rsq.f32.ll
@@ -194,8 +194,8 @@ define amdgpu_kernel void @rsqrt_fmul(ptr addrspace(1) %out, ptr addrspace(1) %i
%b = load volatile float, ptr addrspace(1) %gep.1
%c = load volatile float, ptr addrspace(1) %gep.2
- %x = call contract float @llvm.sqrt.f32(float %a)
- %y = fmul contract float %x, %b
+ %x = call arcp contract float @llvm.sqrt.f32(float %a)
+ %y = fmul arcp contract float %x, %b
%z = fdiv arcp afn contract float %c, %y
store float %z, ptr addrspace(1) %out.gep
ret void
@@ -756,7 +756,7 @@ define { float, float } @v_rsq_f32_multi_use(float %val) {
; CI-IEEE-SAFE-NEXT: v_sub_i32_e32 v2, vcc, 0, v2
; CI-IEEE-SAFE-NEXT: v_ldexp_f32_e32 v1, v1, v2
; CI-IEEE-SAFE-NEXT: s_setpc_b64 s[30:31]
- %sqrt = call afn contract float @llvm.sqrt.f32(float %val), !fpmath !1
+ %sqrt = call arcp afn contract float @llvm.sqrt.f32(float %val), !fpmath !1
%insert.0 = insertvalue { float, float } poison, float %sqrt, 0
%div = fdiv arcp afn contract float 1.0, %sqrt, !fpmath !1
%insert.1 = insertvalue { float, float } %insert.0, float %div, 1
@@ -838,7 +838,7 @@ define float @v_rsq_f32_missing_contract0(float %val) {
; CI-IEEE-SAFE-NEXT: v_sub_i32_e32 v0, vcc, 0, v0
; CI-IEEE-SAFE-NEXT: v_ldexp_f32_e32 v0, v1, v0
; CI-IEEE-SAFE-NEXT: s_setpc_b64 s[30:31]
- %sqrt = call afn float @llvm.sqrt.f32(float %val), !fpmath !1
+ %sqrt = call arcp afn float @llvm.sqrt.f32(float %val), !fpmath !1
%div = fdiv arcp afn contract float 1.0, %sqrt, !fpmath !1
ret float %div
}
@@ -855,7 +855,7 @@ define float @v_rsq_f32_missing_contract1(float %val) {
; GCN-IEEE-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; GCN-IEEE-NEXT: v_rsq_f32_e32 v0, v0
; GCN-IEEE-NEXT: s_setpc_b64 s[30:31]
- %sqrt = call afn contract float @llvm.sqrt.f32(float %val), !fpmath !1
+ %sqrt = call arcp afn contract float @llvm.sqrt.f32(float %val), !fpmath !1
%div = fdiv arcp afn float 1.0, %sqrt, !fpmath !1
ret float %div
}
diff --git a/llvm/test/CodeGen/NVPTX/sqrt-approx.ll b/llvm/test/CodeGen/NVPTX/sqrt-approx.ll
index 7e4e701af4cd1..59012fd18cb5e 100644
--- a/llvm/test/CodeGen/NVPTX/sqrt-approx.ll
+++ b/llvm/test/CodeGen/NVPTX/sqrt-approx.ll
@@ -53,7 +53,7 @@ define double @test_rsqrt64(double %a) {
; CHECK-NEXT: rsqrt.approx.f64 %rd2, %rd1;
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
; CHECK-NEXT: ret;
- %val = tail call double @llvm.sqrt.f64(double %a)
+ %val = tail call arcp double @llvm.sqrt.f64(double %a)
%ret = fdiv arcp double 1.0, %val
ret double %ret
}
@@ -69,7 +69,7 @@ define double @test_rsqrt64_ftz(double %a) #1 {
; CHECK-NEXT: rsqrt.approx.f64 %rd2, %rd1;
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
; CHECK-NEXT: ret;
- %val = tail call double @llvm.sqrt.f64(double %a)
+ %val = tail call arcp double @llvm.sqrt.f64(double %a)
%ret = fdiv arcp double 1.0, %val
ret double %ret
}
@@ -228,7 +228,7 @@ define float @test_rsqrt32_refined(float %a) #2 {
; CHECK-NEXT: mul.f32 %r6, %r5, %r4;
; CHECK-NEXT: st.param.b32 [func_retval0], %r6;
; CHECK-NEXT: ret;
- %val = tail call float @llvm.sqrt.f32(float %a)
+ %val = tail call arcp float @llvm.sqrt.f32(float %a)
%ret = fdiv arcp contract float 1.0, %val
ret float %ret
}
@@ -283,7 +283,7 @@ define double @test_rsqrt64_refined(double %a) #2 {
; CHECK-NEXT: mul.f64 %rd6, %rd5, %rd4;
; CHECK-NEXT: st.param.b64 [func_retval0], %rd6;
; CHECK-NEXT: ret;
- %val = tail call double @llvm.sqrt.f64(double %a)
+ %val = tail call arcp double @llvm.sqrt.f64(double %a)
%ret = fdiv arcp contract double 1.0, %val
ret double %ret
}
@@ -340,7 +340,7 @@ define float @test_rsqrt32_refined_ftz(float %a) #1 #2 {
; CHECK-NEXT: mul.ftz.f32 %r6, %r5, %r4;
; CHECK-NEXT: st.param.b32 [func_retval0], %r6;
; CHECK-NEXT: ret;
- %val = tail call float @llvm.sqrt.f32(float %a)
+ %val = tail call arcp float @llvm.sqrt.f32(float %a)
%ret = fdiv arcp contract float 1.0, %val
ret float %ret
}
@@ -395,7 +395,7 @@ define double @test_rsqrt64_refined_ftz(double %a) #1 #2 {
; CHECK-NEXT: mul.f64 %rd6, %rd5, %rd4;
; CHECK-NEXT: st.param.b64 [func_retval0], %rd6;
; CHECK-NEXT: ret;
- %val = tail call double @llvm.sqrt.f64(double %a)
+ %val = tail call arcp double @llvm.sqrt.f64(double %a)
%ret = fdiv arcp contract double 1.0, %val
ret double %ret
}
diff --git a/llvm/test/CodeGen/PowerPC/recipest.ll b/llvm/test/CodeGen/PowerPC/recipest.ll
index 4bf572bb02942..55001458d1715 100644
--- a/llvm/test/CodeGen/PowerPC/recipest.ll
+++ b/llvm/test/CodeGen/PowerPC/recipest.ll
@@ -163,8 +163,8 @@ define double @foof_fmf(double %a, float %b) nounwind {
; CHECK-P9-NEXT: xsmulsp f0, f0, f3
; CHECK-P9-NEXT: xsmuldp f1, f1, f0
; CHECK-P9-NEXT: blr
- %x = call contract reassoc arcp float @llvm.sqrt.f32(float %b)
- %y = fpext float %x to double
+ %x = tail call contract reassoc arcp float @llvm.sqrt.f32(float %b)
+ %y = fpext reassoc arcp float %x to double
%r = fdiv contract reassoc arcp double %a, %y
ret double %r
}
@@ -188,7 +188,7 @@ define double @foof_safe(double %a, float %b) nounwind {
; CHECK-P9-NEXT: xsdivdp f1, f1, f0
; CHECK-P9-NEXT: blr
%x = call float @llvm.sqrt.f32(float %b)
- %y = fpext float %x to double
+ %y = fpext arcp float %x to double
%r = fdiv double %a, %y
ret double %r
}
@@ -253,7 +253,7 @@ define float @food_fmf(float %a, double %b) nounwind {
; CHECK-P9-NEXT: xsmulsp f1, f1, f0
; CHECK-P9-NEXT: blr
%x = call contract reassoc arcp double @llvm.sqrt.f64(double %b)
- %y = fptrunc double %x to float
+ %y = fptrunc arcp double %x to float
%r = fdiv contract reassoc arcp float %a, %y
ret float %r
}
@@ -280,7 +280,7 @@ define float @food_safe(float %a, double %b) nounwind {
; CHECK-P9-NEXT: xsdivsp f1, f1, f0
; CHECK-P9-NEXT: blr
%x = call double @llvm.sqrt.f64(double %b)
- %y = fptrunc double %x to float
+ %y = fptrunc arcp double %x to float
%r = fdiv float %a, %y
ret float %r
}
@@ -433,7 +433,7 @@ define float @rsqrt_fmul_fmf(float %a, float %b, float %c) {
; CHECK-P9-NEXT: xsmulsp f1, f3, f4
; CHECK-P9-NEXT: blr
%x = call contract reassoc arcp nsz float @llvm.sqrt.f32(float %a)
- %y = fmul contract reassoc nsz float %x, %b
+ %y = fmul contract reassoc arcp nsz float %x, %b
%z = fdiv contract reassoc arcp nsz ninf float %c, %y
ret float %z
}
diff --git a/llvm/test/CodeGen/PowerPC/vsx-fma-mutate-trivial-copy.ll b/llvm/test/CodeGen/PowerPC/vsx-fma-mutate-trivial-copy.ll
index 539b563691723..96a142d0a5634 100644
--- a/llvm/test/CodeGen/PowerPC/vsx-fma-mutate-trivial-copy.ll
+++ b/llvm/test/CodeGen/PowerPC/vsx-fma-mutate-trivial-copy.ll
@@ -17,7 +17,7 @@ for.body: ; preds = %for.body, %for.body
%div = fdiv reassoc arcp float 0.000000e+00, %W
%add = fadd reassoc float %div, %d_min
%conv2 = fpext float %add to double
- %0 = tail call double @llvm.sqrt.f64(double %conv2)
+ %0 = tail call arcp double @llvm.sqrt.f64(double %conv2)
%div4 = fdiv reassoc arcp double %conv3, %0
%call = tail call signext i32 @p_col_helper(double %div4) #2
br label %for.body
diff --git a/llvm/test/CodeGen/PowerPC/vsx-recip-est.ll b/llvm/test/CodeGen/PowerPC/vsx-recip-est.ll
index 4b9d17c26d012..245c8e06eba36 100644
--- a/llvm/test/CodeGen/PowerPC/vsx-recip-est.ll
+++ b/llvm/test/CodeGen/PowerPC/vsx-recip-est.ll
@@ -23,7 +23,7 @@ entry:
store float %f, ptr %f.addr, align 4
%0 = load float, ptr %f.addr, align 4
%1 = load float, ptr @b, align 4
- %2 = call float @llvm.sqrt.f32(float %1)
+ %2 = call arcp float @llvm.sqrt.f32(float %1)
%div = fdiv arcp float %0, %2
ret float %div
; CHECK-LABEL: @emit_xsrsqrtesp
@@ -51,7 +51,7 @@ entry:
store double %f, ptr %f.addr, align 8
%0 = load double, ptr %f.addr, align 8
%1 = load double, ptr @d, align 8
- %2 = call double @llvm.sqrt.f64(double %1)
+ %2 = call arcp double @llvm.sqrt.f64(double %1)
%div = fdiv arcp double %0, %2
ret double %div
; CHECK-LABEL: @emit_xsrsqrtedp
diff --git a/llvm/test/CodeGen/X86/fmf-flags.ll b/llvm/test/CodeGen/X86/fmf-flags.ll
index 16ebf70126f8b..602ed79b7a919 100644
--- a/llvm/test/CodeGen/X86/fmf-flags.ll
+++ b/llvm/test/CodeGen/X86/fmf-flags.ll
@@ -99,7 +99,7 @@ define dso_local float @not_so_fast_recip_sqrt(float %x) {
; X86-NEXT: fxch %st(1)
; X86-NEXT: fstps sqrt1
; X86-NEXT: retl
- %y = call float @llvm.sqrt.f32(float %x)
+ %y = call arcp float @llvm.sqrt.f32(float %x)
%z = fdiv fast float 1.0, %y
store float %y, ptr @sqrt1, align 4
%ret = fadd float %z , 14.5
diff --git a/llvm/test/CodeGen/X86/sqrt-fastmath-mir.ll b/llvm/test/CodeGen/X86/sqrt-fastmath-mir.ll
index 18588aada145c..b3359126837e1 100644
--- a/llvm/test/CodeGen/X86/sqrt-fastmath-mir.ll
+++ b/llvm/test/CodeGen/X86/sqrt-fastmath-mir.ll
@@ -114,7 +114,7 @@ define float @rsqrt_ieee(float %f) #0 {
; CHECK-NEXT: [[VMULSSrr5:%[0-9]+]]:fr32 = nnan ninf nsz arcp contract afn reassoc nofpexcept VMULSSrr killed [[VMULSSrr4]], killed [[VFMADD213SSr1]], implicit $mxcsr
; CHECK-NEXT: $xmm0 = COPY [[VMULSSrr5]]
; CHECK-NEXT: RET 0, $xmm0
- %sqrt = tail call float @llvm.sqrt.f32(float %f)
+ %sqrt = tail call arcp float @llvm.sqrt.f32(float %f)
%div = fdiv fast float 1.0, %sqrt
ret float %div
}
@@ -139,7 +139,7 @@ define float @rsqrt_daz(float %f) #1 {
; CHECK-NEXT: [[VMULSSrr5:%[0-9]+]]:fr32 = nnan ninf nsz arcp contract afn reassoc nofpexcept VMULSSrr killed [[VMULSSrr4]], killed [[VFMADD213SSr1]], implicit $mxcsr
; CHECK-NEXT: $xmm0 = COPY [[VMULSSrr5]]
; CHECK-NEXT: RET 0, $xmm0
- %sqrt = tail call float @llvm.sqrt.f32(float %f)
+ %sqrt = tail call arcp float @llvm.sqrt.f32(float %f)
%div = fdiv fast float 1.0, %sqrt
ret float %div
}
diff --git a/llvm/test/CodeGen/X86/sqrt-fastmath.ll b/llvm/test/CodeGen/X86/sqrt-fastmath.ll
index 83bfcd7f04d9c..55b0f5d9f8e77 100644
--- a/llvm/test/CodeGen/X86/sqrt-fastmath.ll
+++ b/llvm/test/CodeGen/X86/sqrt-fastmath.ll
@@ -333,7 +333,7 @@ define float @f32_estimate(float %x) #1 {
; AVX512-NEXT: vmulss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
; AVX512-NEXT: vmulss %xmm0, %xmm1, %xmm0
; AVX512-NEXT: retq
- %sqrt = tail call float @llvm.sqrt.f32(float %x)
+ %sqrt = tail call fast float @llvm.sqrt.f32(float %x)
%div = fdiv fast float 1.0, %sqrt
ret float %div
}
@@ -366,7 +366,7 @@ define <4 x float> @v4f32_no_estimate(<4 x float> %x) #0 {
; AVX-NEXT: vbroadcastss {{.*#+}} xmm1 = [1.0E+0,1.0E+0,1.0E+0,1.0E+0]
; AVX-NEXT: vdivps %xmm0, %xmm1, %xmm0
; AVX-NEXT: retq
- %sqrt = tail call <4 x float> @llvm.sqrt.v4f32(<4 x float> %x)
+ %sqrt = tail call fast <4 x float> @llvm.sqrt.v4f32(<4 x float> %x)
%div = fdiv fast <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt
ret <4 x float> %div
}
@@ -402,7 +402,7 @@ define <4 x float> @v4f32_estimate(<4 x float> %x) #1 {
; AVX512-NEXT: vmulps %xmm0, %xmm1, %xmm0
; AVX512-NEXT: vmulps %xmm2, %xmm0, %xmm0
; AVX512-NEXT: retq
- %sqrt = tail call <4 x float> @llvm.sqrt.v4f32(<4 x float> %x)
+ %sqrt = tail call fast <4 x float> @llvm.sqrt.v4f32(<4 x float> %x)
%div = fdiv fast <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt
ret <4 x float> %div
}
@@ -467,7 +467,7 @@ define <8 x float> @v8f32_no_estimate(<8 x float> %x) #0 {
; AVX512-NEXT: vbroadcastss {{.*#+}} ymm1 = [1.0E+0,1.0E+0,1.0E+0,1.0E+0,1.0E+0,1.0E+0,1.0E+0,1.0E+0]
; AVX512-NEXT: vdivps %ymm0, %ymm1, %ymm0
; AVX512-NEXT: retq
- %sqrt = tail call <8 x float> @llvm.sqrt.v8f32(<8 x float> %x)
+ %sqrt = tail call fast <8 x float> @llvm.sqrt.v8f32(<8 x float> %x)
%div = fdiv fast <8 x float> <float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt
ret <8 x float> %div
}
@@ -511,7 +511,7 @@ define <8 x float> @v8f32_estimate(<8 x float> %x) #1 {
; AVX512-NEXT: vmulps %ymm3, %ymm1, %ymm0
; AVX512-NEXT: vmulps %ymm2, %ymm0, %ymm0
; AVX512-NEXT: retq
- %sqrt = tail call <8 x float> @llvm.sqrt.v8f32(<8 x float> %x)
+ %sqrt = tail call fast <8 x float> @llvm.sqrt.v8f32(<8 x float> %x)
%div = fdiv fast <8 x float> <float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt
ret <8 x float> %div
}
@@ -610,7 +610,7 @@ define <16 x float> @v16f32_estimate(<16 x float> %x) #1 {
; AVX512-NEXT: vmulps {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to16}, %zmm1, %zmm1
; AVX512-NEXT: vmulps %zmm0, %zmm1, %zmm0
; AVX512-NEXT: retq
- %sqrt = tail call <16 x float> @llvm.sqrt.v16f32(<16 x float> %x)
+ %sqrt = tail call fast <16 x float> @llvm.sqrt.v16f32(<16 x float> %x)
%div = fdiv fast <16 x float> <float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt
ret <16 x float> %div
}
@@ -656,8 +656,8 @@ define float @div_sqrt_fabs_f32(float %x, float %y, float %z) {
; AVX512-NEXT: vmulss %xmm2, %xmm0, %xmm0
; AVX512-NEXT: vmulss %xmm1, %xmm0, %xmm0
; AVX512-NEXT: retq
- %s = call fast float @llvm.sqrt.f32(float %z)
- %a = call fast float @llvm.fabs.f32(float %y)
+ %s = call arcp reassoc float @llvm.sqrt.f32(float %z)
+ %a = call float @llvm.fabs.f32(float %y)
%m = fmul fast float %s, ...
[truncated]
|
|
@llvm/pr-subscribers-backend-nvptx Author: Mikołaj Piróg (mikolaj-pirog) ChangesAs in title. Rewrite semantics, as defined here: https://llvm.org/docs/LangRef.html#floating-point-semantics, dictate that when a given transformation happens, all of instructions taking part in this transformation need to have appropriate flag present. In the case of this change I understand these semantics as following:
Patch is 22.01 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/167595.diff 13 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index f144f17d5a8f2..8b5f633b99dd1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -18613,6 +18613,8 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
EVT VT = N->getValueType(0);
SDLoc DL(N);
SDNodeFlags Flags = N->getFlags();
+ SDNodeFlags FlagsN1 = N1->getFlags();
+
SelectionDAG::FlagInserter FlagsInserter(DAG, N);
if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
@@ -18657,18 +18659,25 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
if (Flags.hasAllowReciprocal()) {
// If this FDIV is part of a reciprocal square root, it may be folded
// into a target-specific square root estimate instruction.
- if (N1.getOpcode() == ISD::FSQRT) {
+ // X / sqrt(Y) -> X * rsqrt(Y)
+ bool N1AllowReciprocal = FlagsN1.hasAllowReciprocal();
+ bool N1Op0AllowsReciprocal =
+ N1.getNumOperands() > 0 &&
+ N1.getOperand(0)->getFlags().hasAllowReciprocal();
+ if (N1.getOpcode() == ISD::FSQRT && N1AllowReciprocal) {
if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0)))
return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
} else if (N1.getOpcode() == ISD::FP_EXTEND &&
- N1.getOperand(0).getOpcode() == ISD::FSQRT) {
+ N1.getOperand(0).getOpcode() == ISD::FSQRT &&
+ N1Op0AllowsReciprocal && N1AllowReciprocal) {
if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0).getOperand(0))) {
RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
AddToWorklist(RV.getNode());
return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
}
} else if (N1.getOpcode() == ISD::FP_ROUND &&
- N1.getOperand(0).getOpcode() == ISD::FSQRT) {
+ N1.getOperand(0).getOpcode() == ISD::FSQRT &&
+ N1Op0AllowsReciprocal && N1AllowReciprocal) {
if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0).getOperand(0))) {
RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
AddToWorklist(RV.getNode());
@@ -18688,8 +18697,10 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
if (Sqrt.getNode()) {
// If the other multiply operand is known positive, pull it into the
// sqrt. That will eliminate the division if we convert to an estimate.
- if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
- N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
+ if (N1.hasOneUse() && Sqrt.hasOneUse() &&
+ Sqrt->getFlags().hasAllowReciprocal() &&
+ Sqrt->getFlags().hasAllowReassociation() &&
+ FlagsN1.hasAllowReciprocal() && FlagsN1.hasAllowReassociation()) {
SDValue A;
if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
A = Y.getOperand(0);
@@ -18711,7 +18722,10 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
// We found a FSQRT, so try to make this fold:
// X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
- if (SDValue Rsqrt = buildRsqrtEstimate(Sqrt.getOperand(0))) {
+ SDValue Rsqrt;
+ if (N1AllowReciprocal && Sqrt->getFlags().hasAllowReciprocal() &&
+ (Rsqrt = buildRsqrtEstimate(Sqrt.getOperand(0)))) {
+ Rsqrt = buildRsqrtEstimate(Sqrt.getOperand(0));
SDValue Div = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, Rsqrt, Y);
AddToWorklist(Div.getNode());
return DAG.getNode(ISD::FMUL, DL, VT, N0, Div);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 88b0809b767b5..ecc26029ee152 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -3976,7 +3976,10 @@ void SelectionDAGBuilder::visitFPExt(const User &I) {
SDValue N = getValue(I.getOperand(0));
EVT DestVT = DAG.getTargetLoweringInfo().getValueType(DAG.getDataLayout(),
I.getType());
- setValue(&I, DAG.getNode(ISD::FP_EXTEND, getCurSDLoc(), DestVT, N));
+ SDNodeFlags Flags;
+ if (auto *TruncInst = dyn_cast<FPMathOperator>(&I))
+ Flags.copyFMF(*TruncInst);
+ setValue(&I, DAG.getNode(ISD::FP_EXTEND, getCurSDLoc(), DestVT, N, Flags));
}
void SelectionDAGBuilder::visitFPToUI(const User &I) {
diff --git a/llvm/test/CodeGen/AArch64/sqrt-fastmath.ll b/llvm/test/CodeGen/AArch64/sqrt-fastmath.ll
index e29993d02935a..737b3d903ed0f 100644
--- a/llvm/test/CodeGen/AArch64/sqrt-fastmath.ll
+++ b/llvm/test/CodeGen/AArch64/sqrt-fastmath.ll
@@ -490,7 +490,7 @@ define <2 x double> @sqrt_fdiv_common_operand_vec(<2 x double> %x) nounwind {
; CHECK-NEXT: fmul v0.2d, v0.2d, v1.2d
; CHECK-NEXT: fmul v0.2d, v0.2d, v2.2d
; CHECK-NEXT: ret
- %sqrt = call <2 x double> @llvm.sqrt.v2f64(<2 x double> %x)
+ %sqrt = call arcp <2 x double> @llvm.sqrt.v2f64(<2 x double> %x)
%r = fdiv arcp nsz reassoc <2 x double> %x, %sqrt
ret <2 x double> %r
}
diff --git a/llvm/test/CodeGen/AMDGPU/fdiv_flags.f32.ll b/llvm/test/CodeGen/AMDGPU/fdiv_flags.f32.ll
index 38239c5509318..cf21a0ca1c47b 100644
--- a/llvm/test/CodeGen/AMDGPU/fdiv_flags.f32.ll
+++ b/llvm/test/CodeGen/AMDGPU/fdiv_flags.f32.ll
@@ -981,7 +981,7 @@ define float @v_fdiv_recip_sqrt_f32_arcp_fdiv_only(float %x) {
; IR-DAZ-GISEL-NEXT: v_div_fmas_f32 v1, v1, v2, v4
; IR-DAZ-GISEL-NEXT: v_div_fixup_f32 v0, v1, v0, 1.0
; IR-DAZ-GISEL-NEXT: s_setpc_b64 s[30:31]
- %sqrt = call float @llvm.sqrt.f32(float %x)
+ %sqrt = call arcp float @llvm.sqrt.f32(float %x)
%fdiv = fdiv arcp float 1.0, %sqrt
ret float %fdiv
}
@@ -1297,7 +1297,7 @@ define float @v_fdiv_recip_sqrt_f32_arcp_afn_fdiv_only(float %x) {
; IR-DAZ-GISEL-NEXT: v_cndmask_b32_e32 v0, v1, v0, vcc
; IR-DAZ-GISEL-NEXT: v_rcp_f32_e32 v0, v0
; IR-DAZ-GISEL-NEXT: s_setpc_b64 s[30:31]
- %sqrt = call float @llvm.sqrt.f32(float %x)
+ %sqrt = call arcp float @llvm.sqrt.f32(float %x)
%fdiv = fdiv arcp afn float 1.0, %sqrt
ret float %fdiv
}
diff --git a/llvm/test/CodeGen/AMDGPU/fsqrt.r600.ll b/llvm/test/CodeGen/AMDGPU/fsqrt.r600.ll
index c93c077706046..d1cb9632bc9fe 100644
--- a/llvm/test/CodeGen/AMDGPU/fsqrt.r600.ll
+++ b/llvm/test/CodeGen/AMDGPU/fsqrt.r600.ll
@@ -228,7 +228,7 @@ define amdgpu_kernel void @recip_sqrt(ptr addrspace(1) %out, float %src) nounwin
; R600-NEXT: LSHR T0.X, KC0[2].Y, literal.x,
; R600-NEXT: RECIPSQRT_IEEE * T1.X, KC0[2].Z,
; R600-NEXT: 2(2.802597e-45), 0(0.000000e+00)
- %sqrt = call float @llvm.sqrt.f32(float %src)
+ %sqrt = call arcp float @llvm.sqrt.f32(float %src)
%recipsqrt = fdiv fast float 1.0, %sqrt
store float %recipsqrt, ptr addrspace(1) %out, align 4
ret void
diff --git a/llvm/test/CodeGen/AMDGPU/rsq.f32.ll b/llvm/test/CodeGen/AMDGPU/rsq.f32.ll
index f967e951b27a4..03e258fc84d61 100644
--- a/llvm/test/CodeGen/AMDGPU/rsq.f32.ll
+++ b/llvm/test/CodeGen/AMDGPU/rsq.f32.ll
@@ -194,8 +194,8 @@ define amdgpu_kernel void @rsqrt_fmul(ptr addrspace(1) %out, ptr addrspace(1) %i
%b = load volatile float, ptr addrspace(1) %gep.1
%c = load volatile float, ptr addrspace(1) %gep.2
- %x = call contract float @llvm.sqrt.f32(float %a)
- %y = fmul contract float %x, %b
+ %x = call arcp contract float @llvm.sqrt.f32(float %a)
+ %y = fmul arcp contract float %x, %b
%z = fdiv arcp afn contract float %c, %y
store float %z, ptr addrspace(1) %out.gep
ret void
@@ -756,7 +756,7 @@ define { float, float } @v_rsq_f32_multi_use(float %val) {
; CI-IEEE-SAFE-NEXT: v_sub_i32_e32 v2, vcc, 0, v2
; CI-IEEE-SAFE-NEXT: v_ldexp_f32_e32 v1, v1, v2
; CI-IEEE-SAFE-NEXT: s_setpc_b64 s[30:31]
- %sqrt = call afn contract float @llvm.sqrt.f32(float %val), !fpmath !1
+ %sqrt = call arcp afn contract float @llvm.sqrt.f32(float %val), !fpmath !1
%insert.0 = insertvalue { float, float } poison, float %sqrt, 0
%div = fdiv arcp afn contract float 1.0, %sqrt, !fpmath !1
%insert.1 = insertvalue { float, float } %insert.0, float %div, 1
@@ -838,7 +838,7 @@ define float @v_rsq_f32_missing_contract0(float %val) {
; CI-IEEE-SAFE-NEXT: v_sub_i32_e32 v0, vcc, 0, v0
; CI-IEEE-SAFE-NEXT: v_ldexp_f32_e32 v0, v1, v0
; CI-IEEE-SAFE-NEXT: s_setpc_b64 s[30:31]
- %sqrt = call afn float @llvm.sqrt.f32(float %val), !fpmath !1
+ %sqrt = call arcp afn float @llvm.sqrt.f32(float %val), !fpmath !1
%div = fdiv arcp afn contract float 1.0, %sqrt, !fpmath !1
ret float %div
}
@@ -855,7 +855,7 @@ define float @v_rsq_f32_missing_contract1(float %val) {
; GCN-IEEE-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; GCN-IEEE-NEXT: v_rsq_f32_e32 v0, v0
; GCN-IEEE-NEXT: s_setpc_b64 s[30:31]
- %sqrt = call afn contract float @llvm.sqrt.f32(float %val), !fpmath !1
+ %sqrt = call arcp afn contract float @llvm.sqrt.f32(float %val), !fpmath !1
%div = fdiv arcp afn float 1.0, %sqrt, !fpmath !1
ret float %div
}
diff --git a/llvm/test/CodeGen/NVPTX/sqrt-approx.ll b/llvm/test/CodeGen/NVPTX/sqrt-approx.ll
index 7e4e701af4cd1..59012fd18cb5e 100644
--- a/llvm/test/CodeGen/NVPTX/sqrt-approx.ll
+++ b/llvm/test/CodeGen/NVPTX/sqrt-approx.ll
@@ -53,7 +53,7 @@ define double @test_rsqrt64(double %a) {
; CHECK-NEXT: rsqrt.approx.f64 %rd2, %rd1;
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
; CHECK-NEXT: ret;
- %val = tail call double @llvm.sqrt.f64(double %a)
+ %val = tail call arcp double @llvm.sqrt.f64(double %a)
%ret = fdiv arcp double 1.0, %val
ret double %ret
}
@@ -69,7 +69,7 @@ define double @test_rsqrt64_ftz(double %a) #1 {
; CHECK-NEXT: rsqrt.approx.f64 %rd2, %rd1;
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
; CHECK-NEXT: ret;
- %val = tail call double @llvm.sqrt.f64(double %a)
+ %val = tail call arcp double @llvm.sqrt.f64(double %a)
%ret = fdiv arcp double 1.0, %val
ret double %ret
}
@@ -228,7 +228,7 @@ define float @test_rsqrt32_refined(float %a) #2 {
; CHECK-NEXT: mul.f32 %r6, %r5, %r4;
; CHECK-NEXT: st.param.b32 [func_retval0], %r6;
; CHECK-NEXT: ret;
- %val = tail call float @llvm.sqrt.f32(float %a)
+ %val = tail call arcp float @llvm.sqrt.f32(float %a)
%ret = fdiv arcp contract float 1.0, %val
ret float %ret
}
@@ -283,7 +283,7 @@ define double @test_rsqrt64_refined(double %a) #2 {
; CHECK-NEXT: mul.f64 %rd6, %rd5, %rd4;
; CHECK-NEXT: st.param.b64 [func_retval0], %rd6;
; CHECK-NEXT: ret;
- %val = tail call double @llvm.sqrt.f64(double %a)
+ %val = tail call arcp double @llvm.sqrt.f64(double %a)
%ret = fdiv arcp contract double 1.0, %val
ret double %ret
}
@@ -340,7 +340,7 @@ define float @test_rsqrt32_refined_ftz(float %a) #1 #2 {
; CHECK-NEXT: mul.ftz.f32 %r6, %r5, %r4;
; CHECK-NEXT: st.param.b32 [func_retval0], %r6;
; CHECK-NEXT: ret;
- %val = tail call float @llvm.sqrt.f32(float %a)
+ %val = tail call arcp float @llvm.sqrt.f32(float %a)
%ret = fdiv arcp contract float 1.0, %val
ret float %ret
}
@@ -395,7 +395,7 @@ define double @test_rsqrt64_refined_ftz(double %a) #1 #2 {
; CHECK-NEXT: mul.f64 %rd6, %rd5, %rd4;
; CHECK-NEXT: st.param.b64 [func_retval0], %rd6;
; CHECK-NEXT: ret;
- %val = tail call double @llvm.sqrt.f64(double %a)
+ %val = tail call arcp double @llvm.sqrt.f64(double %a)
%ret = fdiv arcp contract double 1.0, %val
ret double %ret
}
diff --git a/llvm/test/CodeGen/PowerPC/recipest.ll b/llvm/test/CodeGen/PowerPC/recipest.ll
index 4bf572bb02942..55001458d1715 100644
--- a/llvm/test/CodeGen/PowerPC/recipest.ll
+++ b/llvm/test/CodeGen/PowerPC/recipest.ll
@@ -163,8 +163,8 @@ define double @foof_fmf(double %a, float %b) nounwind {
; CHECK-P9-NEXT: xsmulsp f0, f0, f3
; CHECK-P9-NEXT: xsmuldp f1, f1, f0
; CHECK-P9-NEXT: blr
- %x = call contract reassoc arcp float @llvm.sqrt.f32(float %b)
- %y = fpext float %x to double
+ %x = tail call contract reassoc arcp float @llvm.sqrt.f32(float %b)
+ %y = fpext reassoc arcp float %x to double
%r = fdiv contract reassoc arcp double %a, %y
ret double %r
}
@@ -188,7 +188,7 @@ define double @foof_safe(double %a, float %b) nounwind {
; CHECK-P9-NEXT: xsdivdp f1, f1, f0
; CHECK-P9-NEXT: blr
%x = call float @llvm.sqrt.f32(float %b)
- %y = fpext float %x to double
+ %y = fpext arcp float %x to double
%r = fdiv double %a, %y
ret double %r
}
@@ -253,7 +253,7 @@ define float @food_fmf(float %a, double %b) nounwind {
; CHECK-P9-NEXT: xsmulsp f1, f1, f0
; CHECK-P9-NEXT: blr
%x = call contract reassoc arcp double @llvm.sqrt.f64(double %b)
- %y = fptrunc double %x to float
+ %y = fptrunc arcp double %x to float
%r = fdiv contract reassoc arcp float %a, %y
ret float %r
}
@@ -280,7 +280,7 @@ define float @food_safe(float %a, double %b) nounwind {
; CHECK-P9-NEXT: xsdivsp f1, f1, f0
; CHECK-P9-NEXT: blr
%x = call double @llvm.sqrt.f64(double %b)
- %y = fptrunc double %x to float
+ %y = fptrunc arcp double %x to float
%r = fdiv float %a, %y
ret float %r
}
@@ -433,7 +433,7 @@ define float @rsqrt_fmul_fmf(float %a, float %b, float %c) {
; CHECK-P9-NEXT: xsmulsp f1, f3, f4
; CHECK-P9-NEXT: blr
%x = call contract reassoc arcp nsz float @llvm.sqrt.f32(float %a)
- %y = fmul contract reassoc nsz float %x, %b
+ %y = fmul contract reassoc arcp nsz float %x, %b
%z = fdiv contract reassoc arcp nsz ninf float %c, %y
ret float %z
}
diff --git a/llvm/test/CodeGen/PowerPC/vsx-fma-mutate-trivial-copy.ll b/llvm/test/CodeGen/PowerPC/vsx-fma-mutate-trivial-copy.ll
index 539b563691723..96a142d0a5634 100644
--- a/llvm/test/CodeGen/PowerPC/vsx-fma-mutate-trivial-copy.ll
+++ b/llvm/test/CodeGen/PowerPC/vsx-fma-mutate-trivial-copy.ll
@@ -17,7 +17,7 @@ for.body: ; preds = %for.body, %for.body
%div = fdiv reassoc arcp float 0.000000e+00, %W
%add = fadd reassoc float %div, %d_min
%conv2 = fpext float %add to double
- %0 = tail call double @llvm.sqrt.f64(double %conv2)
+ %0 = tail call arcp double @llvm.sqrt.f64(double %conv2)
%div4 = fdiv reassoc arcp double %conv3, %0
%call = tail call signext i32 @p_col_helper(double %div4) #2
br label %for.body
diff --git a/llvm/test/CodeGen/PowerPC/vsx-recip-est.ll b/llvm/test/CodeGen/PowerPC/vsx-recip-est.ll
index 4b9d17c26d012..245c8e06eba36 100644
--- a/llvm/test/CodeGen/PowerPC/vsx-recip-est.ll
+++ b/llvm/test/CodeGen/PowerPC/vsx-recip-est.ll
@@ -23,7 +23,7 @@ entry:
store float %f, ptr %f.addr, align 4
%0 = load float, ptr %f.addr, align 4
%1 = load float, ptr @b, align 4
- %2 = call float @llvm.sqrt.f32(float %1)
+ %2 = call arcp float @llvm.sqrt.f32(float %1)
%div = fdiv arcp float %0, %2
ret float %div
; CHECK-LABEL: @emit_xsrsqrtesp
@@ -51,7 +51,7 @@ entry:
store double %f, ptr %f.addr, align 8
%0 = load double, ptr %f.addr, align 8
%1 = load double, ptr @d, align 8
- %2 = call double @llvm.sqrt.f64(double %1)
+ %2 = call arcp double @llvm.sqrt.f64(double %1)
%div = fdiv arcp double %0, %2
ret double %div
; CHECK-LABEL: @emit_xsrsqrtedp
diff --git a/llvm/test/CodeGen/X86/fmf-flags.ll b/llvm/test/CodeGen/X86/fmf-flags.ll
index 16ebf70126f8b..602ed79b7a919 100644
--- a/llvm/test/CodeGen/X86/fmf-flags.ll
+++ b/llvm/test/CodeGen/X86/fmf-flags.ll
@@ -99,7 +99,7 @@ define dso_local float @not_so_fast_recip_sqrt(float %x) {
; X86-NEXT: fxch %st(1)
; X86-NEXT: fstps sqrt1
; X86-NEXT: retl
- %y = call float @llvm.sqrt.f32(float %x)
+ %y = call arcp float @llvm.sqrt.f32(float %x)
%z = fdiv fast float 1.0, %y
store float %y, ptr @sqrt1, align 4
%ret = fadd float %z , 14.5
diff --git a/llvm/test/CodeGen/X86/sqrt-fastmath-mir.ll b/llvm/test/CodeGen/X86/sqrt-fastmath-mir.ll
index 18588aada145c..b3359126837e1 100644
--- a/llvm/test/CodeGen/X86/sqrt-fastmath-mir.ll
+++ b/llvm/test/CodeGen/X86/sqrt-fastmath-mir.ll
@@ -114,7 +114,7 @@ define float @rsqrt_ieee(float %f) #0 {
; CHECK-NEXT: [[VMULSSrr5:%[0-9]+]]:fr32 = nnan ninf nsz arcp contract afn reassoc nofpexcept VMULSSrr killed [[VMULSSrr4]], killed [[VFMADD213SSr1]], implicit $mxcsr
; CHECK-NEXT: $xmm0 = COPY [[VMULSSrr5]]
; CHECK-NEXT: RET 0, $xmm0
- %sqrt = tail call float @llvm.sqrt.f32(float %f)
+ %sqrt = tail call arcp float @llvm.sqrt.f32(float %f)
%div = fdiv fast float 1.0, %sqrt
ret float %div
}
@@ -139,7 +139,7 @@ define float @rsqrt_daz(float %f) #1 {
; CHECK-NEXT: [[VMULSSrr5:%[0-9]+]]:fr32 = nnan ninf nsz arcp contract afn reassoc nofpexcept VMULSSrr killed [[VMULSSrr4]], killed [[VFMADD213SSr1]], implicit $mxcsr
; CHECK-NEXT: $xmm0 = COPY [[VMULSSrr5]]
; CHECK-NEXT: RET 0, $xmm0
- %sqrt = tail call float @llvm.sqrt.f32(float %f)
+ %sqrt = tail call arcp float @llvm.sqrt.f32(float %f)
%div = fdiv fast float 1.0, %sqrt
ret float %div
}
diff --git a/llvm/test/CodeGen/X86/sqrt-fastmath.ll b/llvm/test/CodeGen/X86/sqrt-fastmath.ll
index 83bfcd7f04d9c..55b0f5d9f8e77 100644
--- a/llvm/test/CodeGen/X86/sqrt-fastmath.ll
+++ b/llvm/test/CodeGen/X86/sqrt-fastmath.ll
@@ -333,7 +333,7 @@ define float @f32_estimate(float %x) #1 {
; AVX512-NEXT: vmulss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
; AVX512-NEXT: vmulss %xmm0, %xmm1, %xmm0
; AVX512-NEXT: retq
- %sqrt = tail call float @llvm.sqrt.f32(float %x)
+ %sqrt = tail call fast float @llvm.sqrt.f32(float %x)
%div = fdiv fast float 1.0, %sqrt
ret float %div
}
@@ -366,7 +366,7 @@ define <4 x float> @v4f32_no_estimate(<4 x float> %x) #0 {
; AVX-NEXT: vbroadcastss {{.*#+}} xmm1 = [1.0E+0,1.0E+0,1.0E+0,1.0E+0]
; AVX-NEXT: vdivps %xmm0, %xmm1, %xmm0
; AVX-NEXT: retq
- %sqrt = tail call <4 x float> @llvm.sqrt.v4f32(<4 x float> %x)
+ %sqrt = tail call fast <4 x float> @llvm.sqrt.v4f32(<4 x float> %x)
%div = fdiv fast <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt
ret <4 x float> %div
}
@@ -402,7 +402,7 @@ define <4 x float> @v4f32_estimate(<4 x float> %x) #1 {
; AVX512-NEXT: vmulps %xmm0, %xmm1, %xmm0
; AVX512-NEXT: vmulps %xmm2, %xmm0, %xmm0
; AVX512-NEXT: retq
- %sqrt = tail call <4 x float> @llvm.sqrt.v4f32(<4 x float> %x)
+ %sqrt = tail call fast <4 x float> @llvm.sqrt.v4f32(<4 x float> %x)
%div = fdiv fast <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt
ret <4 x float> %div
}
@@ -467,7 +467,7 @@ define <8 x float> @v8f32_no_estimate(<8 x float> %x) #0 {
; AVX512-NEXT: vbroadcastss {{.*#+}} ymm1 = [1.0E+0,1.0E+0,1.0E+0,1.0E+0,1.0E+0,1.0E+0,1.0E+0,1.0E+0]
; AVX512-NEXT: vdivps %ymm0, %ymm1, %ymm0
; AVX512-NEXT: retq
- %sqrt = tail call <8 x float> @llvm.sqrt.v8f32(<8 x float> %x)
+ %sqrt = tail call fast <8 x float> @llvm.sqrt.v8f32(<8 x float> %x)
%div = fdiv fast <8 x float> <float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt
ret <8 x float> %div
}
@@ -511,7 +511,7 @@ define <8 x float> @v8f32_estimate(<8 x float> %x) #1 {
; AVX512-NEXT: vmulps %ymm3, %ymm1, %ymm0
; AVX512-NEXT: vmulps %ymm2, %ymm0, %ymm0
; AVX512-NEXT: retq
- %sqrt = tail call <8 x float> @llvm.sqrt.v8f32(<8 x float> %x)
+ %sqrt = tail call fast <8 x float> @llvm.sqrt.v8f32(<8 x float> %x)
%div = fdiv fast <8 x float> <float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt
ret <8 x float> %div
}
@@ -610,7 +610,7 @@ define <16 x float> @v16f32_estimate(<16 x float> %x) #1 {
; AVX512-NEXT: vmulps {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to16}, %zmm1, %zmm1
; AVX512-NEXT: vmulps %zmm0, %zmm1, %zmm0
; AVX512-NEXT: retq
- %sqrt = tail call <16 x float> @llvm.sqrt.v16f32(<16 x float> %x)
+ %sqrt = tail call fast <16 x float> @llvm.sqrt.v16f32(<16 x float> %x)
%div = fdiv fast <16 x float> <float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt
ret <16 x float> %div
}
@@ -656,8 +656,8 @@ define float @div_sqrt_fabs_f32(float %x, float %y, float %z) {
; AVX512-NEXT: vmulss %xmm2, %xmm0, %xmm0
; AVX512-NEXT: vmulss %xmm1, %xmm0, %xmm0
; AVX512-NEXT: retq
- %s = call fast float @llvm.sqrt.f32(float %z)
- %a = call fast float @llvm.fabs.f32(float %y)
+ %s = call arcp reassoc float @llvm.sqrt.f32(float %z)
+ %a = call float @llvm.fabs.f32(float %y)
%m = fmul fast float %s, ...
[truncated]
|
|
This PR has this fix #167574 included -- it should be merged after it |
|
I was working through the legality of the various transformations by hand, and I realized we have no actual specific license for which fast-math flags enable Looking at the history, https://reviews.llvm.org/D47954 was where the code was adjusted to check for
No disagreements here, caveated on the above issue.
The In other words, I agree with your result, again caveated on the first issue mentioned.
Agreed, again caveated on the first issue. |
|
@jcranmer-intel Thanks for the in-depth feedback! I agree that Good point about My goal with this PR was for it to be more about adding missing checks on nodes that are being rewritten, less about what does a fast-math flag mean -- but I guess it's impossible to do that without settling on some kind of definition of a given flag (I do implicitly assume that |
As in title. Rewrite semantics, as defined here: https://llvm.org/docs/LangRef.html#floating-point-semantics, dictate that when a given transformation happens, all of instructions taking part in this transformation need to have appropriate flag present. In the case of this change I understand these semantics as following:
fdiv X, sqrt(Y) -> fmul X, rsqrt(Y)--sqrtneeds to havearcpbecause it's rewritten intorsqrtfdiv X, fpext/fpround(sqrt(Y))-> fmul X, fpext/fpround(rsqt(Y)) --fpext/fproundneeds to havearcpin addition torsqrtbecausefpext/fpround(sqrt)is rewritten forrsqrtfdiv X, (fmul A, sqrt(A)) -> fmul X, rqsrt(A*A*A)--fmulandsqrtneedreassocbecause they get rewritten intosqrt(A*A*A). I don't believe thatfdivneeds to havereassocsince only it's operand gets rewritten inreassocsense.sqrt,fdivandfmulneedarcpto be able to rewrite intorsqrtafterreassocrewrite --fmulneeds to havearcpbecause afterreassocrewrite, newly created node should have an intersection of parents' flags -- withoutarcponfmul, there will be noarcponsqrt, which is needed (we don't create this node, since there's no need, but I think we should understand the semantics as-if it was created).fdiv X, (fmul Y, sqrt(Z)) -> fmul X, (fdiv rsqrt(Z), Y)--fdiv,fmulandsqrt-- all are being rewritten, so all needarcp.