Skip to content

Commit

Permalink
[SVE][CodeGen] Enable reciprocal estimates for scalable fdiv/fsqrt
Browse files Browse the repository at this point in the history
This patch enables the use of reciprocal estimates for SVE
when both the -Ofast and -mrecip flags are used.

Reviewed By: david-arm, paulwalker-arm

Differential Revision: https://reviews.llvm.org/D111657
  • Loading branch information
kmclaughlin-arm committed Oct 25, 2021
1 parent 5fd55b1 commit 1f49b71
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 8 deletions.
22 changes: 18 additions & 4 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Expand Up @@ -4130,6 +4130,18 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
case Intrinsic::aarch64_sve_frecpx:
return DAG.getNode(AArch64ISD::FRECPX_MERGE_PASSTHRU, dl, Op.getValueType(),
Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
case Intrinsic::aarch64_sve_frecpe_x:
return DAG.getNode(AArch64ISD::FRECPE, dl, Op.getValueType(),
Op.getOperand(1));
case Intrinsic::aarch64_sve_frecps_x:
return DAG.getNode(AArch64ISD::FRECPS, dl, Op.getValueType(),
Op.getOperand(1), Op.getOperand(2));
case Intrinsic::aarch64_sve_frsqrte_x:
return DAG.getNode(AArch64ISD::FRSQRTE, dl, Op.getValueType(),
Op.getOperand(1));
case Intrinsic::aarch64_sve_frsqrts_x:
return DAG.getNode(AArch64ISD::FRSQRTS, dl, Op.getValueType(),
Op.getOperand(1), Op.getOperand(2));
case Intrinsic::aarch64_sve_fabs:
return DAG.getNode(AArch64ISD::FABS_MERGE_PASSTHRU, dl, Op.getValueType(),
Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
Expand Down Expand Up @@ -8235,10 +8247,12 @@ static SDValue getEstimate(const AArch64Subtarget *ST, unsigned Opcode,
SDValue Operand, SelectionDAG &DAG,
int &ExtraSteps) {
EVT VT = Operand.getValueType();
if (ST->hasNEON() &&
(VT == MVT::f64 || VT == MVT::v1f64 || VT == MVT::v2f64 ||
VT == MVT::f32 || VT == MVT::v1f32 ||
VT == MVT::v2f32 || VT == MVT::v4f32)) {
if ((ST->hasNEON() &&
(VT == MVT::f64 || VT == MVT::v1f64 || VT == MVT::v2f64 ||
VT == MVT::f32 || VT == MVT::v1f32 || VT == MVT::v2f32 ||
VT == MVT::v4f32)) ||
(ST->hasSVE() &&
(VT == MVT::nxv8f16 || VT == MVT::nxv4f32 || VT == MVT::nxv2f64))) {
if (ExtraSteps == TargetLoweringBase::ReciprocalEstimate::Unspecified)
// For the reciprocal estimates, convergence is quadratic, so the number
// of digits is doubled after each iteration. In ARMv8, the accuracy of
Expand Down
8 changes: 4 additions & 4 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Expand Up @@ -402,8 +402,8 @@ let Predicates = [HasSVEorStreamingSVE] in {
defm SMIN_ZPZZ : sve_int_bin_pred_bhsd<AArch64smin_p>;
defm UMIN_ZPZZ : sve_int_bin_pred_bhsd<AArch64umin_p>;

defm FRECPE_ZZ : sve_fp_2op_u_zd<0b110, "frecpe", int_aarch64_sve_frecpe_x>;
defm FRSQRTE_ZZ : sve_fp_2op_u_zd<0b111, "frsqrte", int_aarch64_sve_frsqrte_x>;
defm FRECPE_ZZ : sve_fp_2op_u_zd<0b110, "frecpe", AArch64frecpe>;
defm FRSQRTE_ZZ : sve_fp_2op_u_zd<0b111, "frsqrte", AArch64frsqrte>;

defm FADD_ZPmI : sve_fp_2op_i_p_zds<0b000, "fadd", "FADD_ZPZI", sve_fpimm_half_one, fpimm_half, fpimm_one, int_aarch64_sve_fadd>;
defm FSUB_ZPmI : sve_fp_2op_i_p_zds<0b001, "fsub", "FSUB_ZPZI", sve_fpimm_half_one, fpimm_half, fpimm_one, int_aarch64_sve_fsub>;
Expand Down Expand Up @@ -484,8 +484,8 @@ let Predicates = [HasSVE] in {
} // End HasSVE

let Predicates = [HasSVEorStreamingSVE] in {
defm FRECPS_ZZZ : sve_fp_3op_u_zd<0b110, "frecps", int_aarch64_sve_frecps_x>;
defm FRSQRTS_ZZZ : sve_fp_3op_u_zd<0b111, "frsqrts", int_aarch64_sve_frsqrts_x>;
defm FRECPS_ZZZ : sve_fp_3op_u_zd<0b110, "frecps", AArch64frecps>;
defm FRSQRTS_ZZZ : sve_fp_3op_u_zd<0b111, "frsqrts", AArch64frsqrts>;
} // End HasSVEorStreamingSVE

let Predicates = [HasSVE] in {
Expand Down
179 changes: 179 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-fp-reciprocal.ll
@@ -0,0 +1,179 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s

; FDIV

define <vscale x 8 x half> @fdiv_8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b) {
; CHECK-LABEL: fdiv_8f16:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.h
; CHECK-NEXT: fdiv z0.h, p0/m, z0.h, z1.h
; CHECK-NEXT: ret
%fdiv = fdiv fast <vscale x 8 x half> %a, %b
ret <vscale x 8 x half> %fdiv
}

define <vscale x 8 x half> @fdiv_recip_8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b) #0 {
; CHECK-LABEL: fdiv_recip_8f16:
; CHECK: // %bb.0:
; CHECK-NEXT: frecpe z2.h, z1.h
; CHECK-NEXT: frecps z3.h, z1.h, z2.h
; CHECK-NEXT: fmul z2.h, z2.h, z3.h
; CHECK-NEXT: frecps z1.h, z1.h, z2.h
; CHECK-NEXT: fmul z1.h, z2.h, z1.h
; CHECK-NEXT: fmul z0.h, z1.h, z0.h
; CHECK-NEXT: ret
%fdiv = fdiv fast <vscale x 8 x half> %a, %b
ret <vscale x 8 x half> %fdiv
}

define <vscale x 4 x float> @fdiv_4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
; CHECK-LABEL: fdiv_4f32:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: fdiv z0.s, p0/m, z0.s, z1.s
; CHECK-NEXT: ret
%fdiv = fdiv fast <vscale x 4 x float> %a, %b
ret <vscale x 4 x float> %fdiv
}

define <vscale x 4 x float> @fdiv_recip_4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b) #0 {
; CHECK-LABEL: fdiv_recip_4f32:
; CHECK: // %bb.0:
; CHECK-NEXT: frecpe z2.s, z1.s
; CHECK-NEXT: frecps z3.s, z1.s, z2.s
; CHECK-NEXT: fmul z2.s, z2.s, z3.s
; CHECK-NEXT: frecps z1.s, z1.s, z2.s
; CHECK-NEXT: fmul z1.s, z2.s, z1.s
; CHECK-NEXT: fmul z0.s, z1.s, z0.s
; CHECK-NEXT: ret
%fdiv = fdiv fast <vscale x 4 x float> %a, %b
ret <vscale x 4 x float> %fdiv
}

define <vscale x 2 x double> @fdiv_2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b) {
; CHECK-LABEL: fdiv_2f64:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: fdiv z0.d, p0/m, z0.d, z1.d
; CHECK-NEXT: ret
%fdiv = fdiv fast <vscale x 2 x double> %a, %b
ret <vscale x 2 x double> %fdiv
}

define <vscale x 2 x double> @fdiv_recip_2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b) #0 {
; CHECK-LABEL: fdiv_recip_2f64:
; CHECK: // %bb.0:
; CHECK-NEXT: frecpe z2.d, z1.d
; CHECK-NEXT: frecps z3.d, z1.d, z2.d
; CHECK-NEXT: fmul z2.d, z2.d, z3.d
; CHECK-NEXT: frecps z3.d, z1.d, z2.d
; CHECK-NEXT: fmul z2.d, z2.d, z3.d
; CHECK-NEXT: frecps z1.d, z1.d, z2.d
; CHECK-NEXT: fmul z1.d, z2.d, z1.d
; CHECK-NEXT: fmul z0.d, z1.d, z0.d
; CHECK-NEXT: ret
%fdiv = fdiv fast <vscale x 2 x double> %a, %b
ret <vscale x 2 x double> %fdiv
}

; FSQRT

define <vscale x 8 x half> @fsqrt_8f16(<vscale x 8 x half> %a) {
; CHECK-LABEL: fsqrt_8f16:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.h
; CHECK-NEXT: fsqrt z0.h, p0/m, z0.h
; CHECK-NEXT: ret
%fsqrt = call fast <vscale x 8 x half> @llvm.sqrt.nxv8f16(<vscale x 8 x half> %a)
ret <vscale x 8 x half> %fsqrt
}

define <vscale x 8 x half> @fsqrt_recip_8f16(<vscale x 8 x half> %a) #0 {
; CHECK-LABEL: fsqrt_recip_8f16:
; CHECK: // %bb.0:
; CHECK-NEXT: frsqrte z1.h, z0.h
; CHECK-NEXT: ptrue p0.h
; CHECK-NEXT: fmul z2.h, z1.h, z1.h
; CHECK-NEXT: fcmeq p0.h, p0/z, z0.h, #0.0
; CHECK-NEXT: frsqrts z2.h, z0.h, z2.h
; CHECK-NEXT: fmul z1.h, z1.h, z2.h
; CHECK-NEXT: fmul z2.h, z1.h, z1.h
; CHECK-NEXT: frsqrts z2.h, z0.h, z2.h
; CHECK-NEXT: fmul z1.h, z1.h, z2.h
; CHECK-NEXT: fmul z1.h, z0.h, z1.h
; CHECK-NEXT: sel z0.h, p0, z0.h, z1.h
; CHECK-NEXT: ret
%fsqrt = call fast <vscale x 8 x half> @llvm.sqrt.nxv8f16(<vscale x 8 x half> %a)
ret <vscale x 8 x half> %fsqrt
}

define <vscale x 4 x float> @fsqrt_4f32(<vscale x 4 x float> %a) {
; CHECK-LABEL: fsqrt_4f32:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: fsqrt z0.s, p0/m, z0.s
; CHECK-NEXT: ret
%fsqrt = call fast <vscale x 4 x float> @llvm.sqrt.nxv4f32(<vscale x 4 x float> %a)
ret <vscale x 4 x float> %fsqrt
}

define <vscale x 4 x float> @fsqrt_recip_4f32(<vscale x 4 x float> %a) #0 {
; CHECK-LABEL: fsqrt_recip_4f32:
; CHECK: // %bb.0:
; CHECK-NEXT: frsqrte z1.s, z0.s
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: fmul z2.s, z1.s, z1.s
; CHECK-NEXT: fcmeq p0.s, p0/z, z0.s, #0.0
; CHECK-NEXT: frsqrts z2.s, z0.s, z2.s
; CHECK-NEXT: fmul z1.s, z1.s, z2.s
; CHECK-NEXT: fmul z2.s, z1.s, z1.s
; CHECK-NEXT: frsqrts z2.s, z0.s, z2.s
; CHECK-NEXT: fmul z1.s, z1.s, z2.s
; CHECK-NEXT: fmul z1.s, z0.s, z1.s
; CHECK-NEXT: sel z0.s, p0, z0.s, z1.s
; CHECK-NEXT: ret
%fsqrt = call fast <vscale x 4 x float> @llvm.sqrt.nxv4f32(<vscale x 4 x float> %a)
ret <vscale x 4 x float> %fsqrt
}

define <vscale x 2 x double> @fsqrt_2f64(<vscale x 2 x double> %a) {
; CHECK-LABEL: fsqrt_2f64:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: fsqrt z0.d, p0/m, z0.d
; CHECK-NEXT: ret
%fsqrt = call fast <vscale x 2 x double> @llvm.sqrt.nxv2f64(<vscale x 2 x double> %a)
ret <vscale x 2 x double> %fsqrt
}

define <vscale x 2 x double> @fsqrt_recip_2f64(<vscale x 2 x double> %a) #0 {
; CHECK-LABEL: fsqrt_recip_2f64:
; CHECK: // %bb.0:
; CHECK-NEXT: frsqrte z1.d, z0.d
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: fmul z2.d, z1.d, z1.d
; CHECK-NEXT: fcmeq p0.d, p0/z, z0.d, #0.0
; CHECK-NEXT: frsqrts z2.d, z0.d, z2.d
; CHECK-NEXT: fmul z1.d, z1.d, z2.d
; CHECK-NEXT: fmul z2.d, z1.d, z1.d
; CHECK-NEXT: frsqrts z2.d, z0.d, z2.d
; CHECK-NEXT: fmul z1.d, z1.d, z2.d
; CHECK-NEXT: fmul z2.d, z1.d, z1.d
; CHECK-NEXT: frsqrts z2.d, z0.d, z2.d
; CHECK-NEXT: fmul z1.d, z1.d, z2.d
; CHECK-NEXT: fmul z1.d, z0.d, z1.d
; CHECK-NEXT: sel z0.d, p0, z0.d, z1.d
; CHECK-NEXT: ret
%fsqrt = call fast <vscale x 2 x double> @llvm.sqrt.nxv2f64(<vscale x 2 x double> %a)
ret <vscale x 2 x double> %fsqrt
}

declare <vscale x 2 x half> @llvm.sqrt.nxv2f16(<vscale x 2 x half>)
declare <vscale x 4 x half> @llvm.sqrt.nxv4f16(<vscale x 4 x half>)
declare <vscale x 8 x half> @llvm.sqrt.nxv8f16(<vscale x 8 x half>)
declare <vscale x 2 x float> @llvm.sqrt.nxv2f32(<vscale x 2 x float>)
declare <vscale x 4 x float> @llvm.sqrt.nxv4f32(<vscale x 4 x float>)
declare <vscale x 2 x double> @llvm.sqrt.nxv2f64(<vscale x 2 x double>)

attributes #0 = { "reciprocal-estimates"="all" }

0 comments on commit 1f49b71

Please sign in to comment.