diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 82066145478391e..b961e5a30cd0f9d 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -145,6 +145,7 @@ static bool isMergePassthruOpcode(unsigned Opc) { case AArch64ISD::FROUND_MERGE_PASSTHRU: case AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU: case AArch64ISD::FTRUNC_MERGE_PASSTHRU: + case AArch64ISD::FSQRT_MERGE_PASSTHRU: return true; } } @@ -990,6 +991,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::FROUND, VT, Custom); setOperationAction(ISD::FROUNDEVEN, VT, Custom); setOperationAction(ISD::FTRUNC, VT, Custom); + setOperationAction(ISD::FSQRT, VT, Custom); } } @@ -1502,6 +1504,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(AArch64ISD::FROUND_MERGE_PASSTHRU) MAKE_CASE(AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU) MAKE_CASE(AArch64ISD::FTRUNC_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::FSQRT_MERGE_PASSTHRU) MAKE_CASE(AArch64ISD::SETCC_MERGE_ZERO) MAKE_CASE(AArch64ISD::ADC) MAKE_CASE(AArch64ISD::SBC) @@ -3385,6 +3388,9 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, case Intrinsic::aarch64_sve_frintz: return DAG.getNode(AArch64ISD::FTRUNC_MERGE_PASSTHRU, dl, Op.getValueType(), Op.getOperand(2), Op.getOperand(3), Op.getOperand(1)); + case Intrinsic::aarch64_sve_fsqrt: + return DAG.getNode(AArch64ISD::FSQRT_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), Op.getOperand(1)); case Intrinsic::aarch64_sve_convert_to_svbool: { EVT OutVT = Op.getValueType(); EVT InVT = Op.getOperand(1).getValueType(); @@ -3696,6 +3702,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, return LowerToPredicatedOp(Op, DAG, AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU); case ISD::FTRUNC: return LowerToPredicatedOp(Op, DAG, AArch64ISD::FTRUNC_MERGE_PASSTHRU); + case ISD::FSQRT: + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSQRT_MERGE_PASSTHRU); case ISD::FP_ROUND: case ISD::STRICT_FP_ROUND: return LowerFP_ROUND(Op, DAG); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index d6e511891752a71..e34caacd272d162 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -102,6 +102,7 @@ enum NodeType : unsigned { FRINT_MERGE_PASSTHRU, FROUND_MERGE_PASSTHRU, FROUNDEVEN_MERGE_PASSTHRU, + FSQRT_MERGE_PASSTHRU, FTRUNC_MERGE_PASSTHRU, SIGN_EXTEND_INREG_MERGE_PASSTHRU, ZERO_EXTEND_INREG_MERGE_PASSTHRU, diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td index e01a34242a8d7fe..63545d30b2d11aa 100644 --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -209,6 +209,7 @@ def AArch64frintx_mt : SDNode<"AArch64ISD::FRINT_MERGE_PASSTHRU", SDT_AArch64Ari def AArch64frinta_mt : SDNode<"AArch64ISD::FROUND_MERGE_PASSTHRU", SDT_AArch64Arith>; def AArch64frintn_mt : SDNode<"AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU", SDT_AArch64Arith>; def AArch64frintz_mt : SDNode<"AArch64ISD::FTRUNC_MERGE_PASSTHRU", SDT_AArch64Arith>; +def AArch64fsqrt_mt : SDNode<"AArch64ISD::FSQRT_MERGE_PASSTHRU", SDT_AArch64Arith>; def SDT_AArch64ReduceWithInit : SDTypeProfile<1, 3, [SDTCisVec<1>, SDTCisVec<3>]>; def AArch64clasta_n : SDNode<"AArch64ISD::CLASTA_N", SDT_AArch64ReduceWithInit>; @@ -1430,7 +1431,7 @@ multiclass sve_prefetch; defm FRINTI_ZPmZ : sve_fp_2op_p_zd_HSD<0b00111, "frinti", null_frag, AArch64frinti_mt>; defm FRECPX_ZPmZ : sve_fp_2op_p_zd_HSD<0b01100, "frecpx", int_aarch64_sve_frecpx>; - defm FSQRT_ZPmZ : sve_fp_2op_p_zd_HSD<0b01101, "fsqrt", int_aarch64_sve_fsqrt>; + defm FSQRT_ZPmZ : sve_fp_2op_p_zd_HSD<0b01101, "fsqrt", null_frag, AArch64fsqrt_mt>; let Predicates = [HasBF16, HasSVE] in { defm BFDOT_ZZZ : sve_bfloat_dot<"bfdot", int_aarch64_sve_bfdot>; diff --git a/llvm/test/CodeGen/AArch64/sve-fp.ll b/llvm/test/CodeGen/AArch64/sve-fp.ll index e4aea2847bc4cd3..5334e66b22f7e31 100644 --- a/llvm/test/CodeGen/AArch64/sve-fp.ll +++ b/llvm/test/CodeGen/AArch64/sve-fp.ll @@ -480,6 +480,68 @@ define void @float_copy(* %P1, * %P2) { ret void } +; FSQRT + +define @fsqrt_nxv8f16( %a) { +; CHECK-LABEL: fsqrt_nxv8f16: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.h +; CHECK-NEXT: fsqrt z0.h, p0/m, z0.h +; CHECK-NEXT: ret + %res = call @llvm.sqrt.nxv8f16( %a) + ret %res +} + +define @fsqrt_nxv4f16( %a) { +; CHECK-LABEL: fsqrt_nxv4f16: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.s +; CHECK-NEXT: fsqrt z0.h, p0/m, z0.h +; CHECK-NEXT: ret + %res = call @llvm.sqrt.nxv4f16( %a) + ret %res +} + +define @fsqrt_nxv2f16( %a) { +; CHECK-LABEL: fsqrt_nxv2f16: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.d +; CHECK-NEXT: fsqrt z0.h, p0/m, z0.h +; CHECK-NEXT: ret + %res = call @llvm.sqrt.nxv2f16( %a) + ret %res +} + +define @fsqrt_nxv4f32( %a) { +; CHECK-LABEL: fsqrt_nxv4f32: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.s +; CHECK-NEXT: fsqrt z0.s, p0/m, z0.s +; CHECK-NEXT: ret + %res = call @llvm.sqrt.nxv4f32( %a) + ret %res +} + +define @fsqrt_nxv2f32( %a) { +; CHECK-LABEL: fsqrt_nxv2f32: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.d +; CHECK-NEXT: fsqrt z0.s, p0/m, z0.s +; CHECK-NEXT: ret + %res = call @llvm.sqrt.nxv2f32( %a) + ret %res +} + +define @fsqrt_nxv2f64( %a) { +; CHECK-LABEL: fsqrt_nxv2f64: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.d +; CHECK-NEXT: fsqrt z0.d, p0/m, z0.d +; CHECK-NEXT: ret + %res = call @llvm.sqrt.nxv2f64( %a) + ret %res +} + declare @llvm.aarch64.sve.frecps.x.nxv8f16(, ) declare @llvm.aarch64.sve.frecps.x.nxv4f32( , ) declare @llvm.aarch64.sve.frecps.x.nxv2f64(, ) @@ -495,5 +557,12 @@ declare @llvm.fma.nxv8f16(, @llvm.fma.nxv4f16(, , ) declare @llvm.fma.nxv2f16(, , ) +declare @llvm.sqrt.nxv8f16( ) +declare @llvm.sqrt.nxv4f16( ) +declare @llvm.sqrt.nxv2f16( ) +declare @llvm.sqrt.nxv4f32() +declare @llvm.sqrt.nxv2f32() +declare @llvm.sqrt.nxv2f64() + ; Function Attrs: nounwind readnone declare double @llvm.aarch64.sve.faddv.nxv2f64(, ) #2