diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 30eb19036ddda..790622da2fb79 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1123,7 +1123,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, ISD::UINT_TO_FP}); setTargetDAGCombine({ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FP_TO_SINT_SAT, - ISD::FP_TO_UINT_SAT, ISD::FADD}); + ISD::FP_TO_UINT_SAT, ISD::FADD, ISD::FMA}); // Try and combine setcc with csel setTargetDAGCombine(ISD::SETCC); @@ -28339,6 +28339,71 @@ static SDValue performCTPOPCombine(SDNode *N, return DAG.getNegative(NegPopCount, DL, VT); } +// Combine manual Newton-Raphson reciprocal square root refinement patterns +// into FRSQRTS instructions. +// +// The Newton-Raphson iteration for rsqrt is: +// r' = r * (1.5 - 0.5 * x * r * r) +// +// This appears as: +// fma(r, 1.5, mul(mul(mul(x, -0.5), r), r * r)) +// where r = frsqrte(x) is the initial estimate. +// +// We convert this to use FRSQRTS: r * frsqrts(x * r, r). +static SDValue +performRSQRTRefinementCombine(SDNode *N, SelectionDAG &DAG, + const AArch64Subtarget *Subtarget) { + using namespace SDPatternMatch; + + if (!Subtarget->useRSqrt()) + return SDValue(); + + if (N->getOpcode() != ISD::FMA) + return SDValue(); + + auto IsFRSQRTE = [](SDValue V) { + if (V.getOpcode() == AArch64ISD::FRSQRTE) + return true; + if (V.getOpcode() == ISD::INTRINSIC_WO_CHAIN) + return V.getConstantOperandVal(0) == Intrinsic::aarch64_neon_frsqrte; + return false; + }; + + // Match: fma(Est, 1.5, MulChain) where Est = frsqrte(x). + SDValue Est = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + SDValue MulChain = N->getOperand(2); + EVT VT = N->getValueType(0); + + if (!IsFRSQRTE(Est) || + !sd_match(Op1, m_SpecificFP(APFloat(VT.getFltSemantics(), "1.5")))) + return SDValue(); + + // Match: MulChain = (X * -0.5 * Est) * (Est * Est). + SDValue Chain; + if (!sd_match(MulChain, m_FMul(m_FMul(m_Specific(Est), m_Deferred(Est)), + m_Value(Chain)))) + return SDValue(); + + // Match Chain = (X * -0.5) * Est. + SDValue XNegHalf; + if (!sd_match(Chain, m_FMul(m_Specific(Est), m_Value(XNegHalf)))) + return SDValue(); + + // Match XNegHalf = X * -0.5. + SDValue X; + if (!sd_match(XNegHalf, + m_FMul(m_Value(X), + m_SpecificFP(APFloat(VT.getFltSemantics(), "-0.5"))))) + return SDValue(); + + // Build the replacement: Est * frsqrts(X * Est, Est). + SDLoc DL(N); + SDValue XTimesEst = DAG.getNode(ISD::FMUL, DL, VT, X, Est); + SDValue Step = DAG.getNode(AArch64ISD::FRSQRTS, DL, VT, XTimesEst, Est); + return DAG.getNode(ISD::FMUL, DL, VT, Est, Step); +} + SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const { SelectionDAG &DAG = DCI.DAG; @@ -28411,6 +28476,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, return performANDCombine(N, DCI); case ISD::FADD: return performFADDCombine(N, DCI); + case ISD::FMA: + return performRSQRTRefinementCombine(N, DAG, Subtarget); case ISD::INTRINSIC_WO_CHAIN: return performIntrinsicCombine(N, DCI, Subtarget); case ISD::ANY_EXTEND: diff --git a/llvm/test/CodeGen/AArch64/aarch64-manual-rsqrt-newton-raphson.ll b/llvm/test/CodeGen/AArch64/aarch64-manual-rsqrt-newton-raphson.ll new file mode 100644 index 0000000000000..2c5ee2e91f8e5 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/aarch64-manual-rsqrt-newton-raphson.ll @@ -0,0 +1,148 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc < %s -mtriple=aarch64-unknown-linux-gnu -mattr=+neon,+use-reciprocal-square-root | FileCheck %s + +; Test that manual Newton-Raphson reciprocal square root refinement patterns +; are recognized and converted to FRSQRTS instructions. + +declare <4 x float> @llvm.aarch64.neon.frsqrte.v4f32(<4 x float>) +declare <2 x float> @llvm.aarch64.neon.frsqrte.v2f32(<2 x float>) +declare <2 x double> @llvm.aarch64.neon.frsqrte.v2f64(<2 x double>) +declare float @llvm.aarch64.neon.frsqrte.f32(float) +declare double @llvm.aarch64.neon.frsqrte.f64(double) +declare <4 x float> @llvm.fma.v4f32(<4 x float>, <4 x float>, <4 x float>) +declare <2 x float> @llvm.fma.v2f32(<2 x float>, <2 x float>, <2 x float>) +declare <2 x double> @llvm.fma.v2f64(<2 x double>, <2 x double>, <2 x double>) +declare float @llvm.fma.f32(float, float, float) +declare double @llvm.fma.f64(double, double, double) + +define <4 x float> @test_fma_pattern(<4 x float> %x) { +; CHECK-LABEL: test_fma_pattern: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: frsqrte v1.4s, v0.4s +; CHECK-NEXT: fmul v0.4s, v0.4s, v1.4s +; CHECK-NEXT: frsqrts v0.4s, v0.4s, v1.4s +; CHECK-NEXT: fmul v0.4s, v1.4s, v0.4s +; CHECK-NEXT: ret +entry: + %rsqrt_est = call <4 x float> @llvm.aarch64.neon.frsqrte.v4f32(<4 x float> %x) + %r_sq = fmul <4 x float> %rsqrt_est, %rsqrt_est + %x_times_neg_half = fmul <4 x float> %x, splat (float -5.000000e-01) + %mul1 = fmul <4 x float> %x_times_neg_half, %rsqrt_est + %mul2 = fmul <4 x float> %mul1, %r_sq + %result = call <4 x float> @llvm.fma.v4f32(<4 x float> %rsqrt_est, <4 x float> splat (float 1.500000e+00), <4 x float> %mul2) + ret <4 x float> %result +} + +define <2 x float> @test_v2f32(<2 x float> %x) { +; CHECK-LABEL: test_v2f32: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: frsqrte v1.2s, v0.2s +; CHECK-NEXT: fmul v0.2s, v0.2s, v1.2s +; CHECK-NEXT: frsqrts v0.2s, v0.2s, v1.2s +; CHECK-NEXT: fmul v0.2s, v1.2s, v0.2s +; CHECK-NEXT: ret +entry: + %rsqrt_est = call <2 x float> @llvm.aarch64.neon.frsqrte.v2f32(<2 x float> %x) + %r_sq = fmul <2 x float> %rsqrt_est, %rsqrt_est + %x_times_neg_half = fmul <2 x float> %x, splat (float -5.000000e-01) + %mul1 = fmul <2 x float> %x_times_neg_half, %rsqrt_est + %mul2 = fmul <2 x float> %mul1, %r_sq + %result = call <2 x float> @llvm.fma.v2f32(<2 x float> %rsqrt_est, <2 x float> splat (float 1.500000e+00), <2 x float> %mul2) + ret <2 x float> %result +} + +define <2 x double> @test_v2f64(<2 x double> %x) { +; CHECK-LABEL: test_v2f64: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: frsqrte v1.2d, v0.2d +; CHECK-NEXT: fmul v0.2d, v0.2d, v1.2d +; CHECK-NEXT: frsqrts v0.2d, v0.2d, v1.2d +; CHECK-NEXT: fmul v0.2d, v1.2d, v0.2d +; CHECK-NEXT: ret +entry: + %rsqrt_est = call <2 x double> @llvm.aarch64.neon.frsqrte.v2f64(<2 x double> %x) + %r_sq = fmul <2 x double> %rsqrt_est, %rsqrt_est + %x_times_neg_half = fmul <2 x double> %x, splat (double -5.000000e-01) + %mul1 = fmul <2 x double> %x_times_neg_half, %rsqrt_est + %mul2 = fmul <2 x double> %mul1, %r_sq + %result = call <2 x double> @llvm.fma.v2f64(<2 x double> %rsqrt_est, <2 x double> splat (double 1.500000e+00), <2 x double> %mul2) + ret <2 x double> %result +} + +define float @test_scalar_f32(float %x) { +; CHECK-LABEL: test_scalar_f32: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: frsqrte s1, s0 +; CHECK-NEXT: fmul s0, s0, s1 +; CHECK-NEXT: frsqrts s0, s0, s1 +; CHECK-NEXT: fmul s0, s1, s0 +; CHECK-NEXT: ret +entry: + %rsqrt_est = call float @llvm.aarch64.neon.frsqrte.f32(float %x) + %r_sq = fmul float %rsqrt_est, %rsqrt_est + %x_times_neg_half = fmul float %x, -5.000000e-01 + %mul1 = fmul float %x_times_neg_half, %rsqrt_est + %mul2 = fmul float %mul1, %r_sq + %result = call float @llvm.fma.f32(float %rsqrt_est, float 1.500000e+00, float %mul2) + ret float %result +} + +define double @test_scalar_f64(double %x) { +; CHECK-LABEL: test_scalar_f64: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: frsqrte d1, d0 +; CHECK-NEXT: fmul d0, d0, d1 +; CHECK-NEXT: frsqrts d0, d0, d1 +; CHECK-NEXT: fmul d0, d1, d0 +; CHECK-NEXT: ret +entry: + %rsqrt_est = call double @llvm.aarch64.neon.frsqrte.f64(double %x) + %r_sq = fmul double %rsqrt_est, %rsqrt_est + %x_times_neg_half = fmul double %x, -5.000000e-01 + %mul1 = fmul double %x_times_neg_half, %rsqrt_est + %mul2 = fmul double %mul1, %r_sq + %result = call double @llvm.fma.f64(double %rsqrt_est, double 1.500000e+00, double %mul2) + ret double %result +} + +define <4 x float> @test_different_constants(<4 x float> %x) { +; CHECK-LABEL: test_different_constants: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: frsqrte v2.4s, v0.4s +; CHECK-NEXT: fmov v1.4s, #-0.75000000 +; CHECK-NEXT: fmul v0.4s, v0.4s, v1.4s +; CHECK-NEXT: fmul v1.4s, v2.4s, v2.4s +; CHECK-NEXT: fmul v0.4s, v0.4s, v2.4s +; CHECK-NEXT: fmul v0.4s, v0.4s, v1.4s +; CHECK-NEXT: fmov v1.4s, #1.50000000 +; CHECK-NEXT: fmla v0.4s, v1.4s, v2.4s +; CHECK-NEXT: ret +entry: + %rsqrt_est = call <4 x float> @llvm.aarch64.neon.frsqrte.v4f32(<4 x float> %x) + %r_sq = fmul <4 x float> %rsqrt_est, %rsqrt_est + %x_times_wrong = fmul <4 x float> %x, splat (float -7.500000e-01) + %mul1 = fmul <4 x float> %x_times_wrong, %rsqrt_est + %mul2 = fmul <4 x float> %mul1, %r_sq + %result = call <4 x float> @llvm.fma.v4f32(<4 x float> %rsqrt_est, <4 x float> splat (float 1.500000e+00), <4 x float> %mul2) + ret <4 x float> %result +} + +define <4 x float> @test_non_frsqrte_estimate(<4 x float> %x, <4 x float> %est) { +; CHECK-LABEL: test_non_frsqrte_estimate: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: movi v2.4s, #191, lsl #24 +; CHECK-NEXT: fmul v3.4s, v1.4s, v1.4s +; CHECK-NEXT: fmul v0.4s, v0.4s, v2.4s +; CHECK-NEXT: fmov v2.4s, #1.50000000 +; CHECK-NEXT: fmul v0.4s, v0.4s, v1.4s +; CHECK-NEXT: fmul v0.4s, v0.4s, v3.4s +; CHECK-NEXT: fmla v0.4s, v2.4s, v1.4s +; CHECK-NEXT: ret +entry: + %r_sq = fmul <4 x float> %est, %est + %x_times_neg_half = fmul <4 x float> %x, splat (float -5.000000e-01) + %mul1 = fmul <4 x float> %x_times_neg_half, %est + %mul2 = fmul <4 x float> %mul1, %r_sq + %result = call <4 x float> @llvm.fma.v4f32(<4 x float> %est, <4 x float> splat (float 1.500000e+00), <4 x float> %mul2) + ret <4 x float> %result +}