diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 40815692c71df..f07c0f022b7f0 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -943,6 +943,7 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM,
     setTargetDAGCombine(ISD::INTRINSIC_W_CHAIN);
     setTargetDAGCombine(ISD::INTRINSIC_VOID);
     setTargetDAGCombine(ISD::VECREDUCE_ADD);
+    setTargetDAGCombine(ISD::ADD);
   }
 
   if (!Subtarget->hasFP64()) {
@@ -1656,6 +1657,10 @@ const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const {
   case ARMISD::VMULLu:        return "ARMISD::VMULLu";
   case ARMISD::VADDVs:        return "ARMISD::VADDVs";
   case ARMISD::VADDVu:        return "ARMISD::VADDVu";
+  case ARMISD::VADDLVs:       return "ARMISD::VADDLVs";
+  case ARMISD::VADDLVu:       return "ARMISD::VADDLVu";
+  case ARMISD::VADDLVAs:      return "ARMISD::VADDLVAs";
+  case ARMISD::VADDLVAu:      return "ARMISD::VADDLVAu";
   case ARMISD::UMAAL:         return "ARMISD::UMAAL";
   case ARMISD::UMLAL:         return "ARMISD::UMLAL";
   case ARMISD::SMLAL:         return "ARMISD::SMLAL";
@@ -11731,6 +11736,53 @@ static SDValue PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
   return SDValue();
 }
 
+static SDValue PerformADDVecReduce(SDNode *N,
+                                   TargetLowering::DAGCombinerInfo &DCI,
+                                   const ARMSubtarget *Subtarget) {
+  if (!Subtarget->hasMVEIntegerOps() || N->getValueType(0) != MVT::i64)
+    return SDValue();
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+
+  // We are looking for an i64 add of a VADDLVx. Due to these being i64's, this
+  // will look like:
+  //   t1: i32,i32 = ARMISD::VADDLVs x
+  //   t2: i64 = build_pair t1, t1:1
+  //   t3: i64 = add t2, y
+  // We also need to check for sext / zext and commutative adds.
+  auto MakeVecReduce = [&](unsigned Opcode, unsigned OpcodeA, SDValue NA,
+                           SDValue NB) {
+    if (NB->getOpcode() != ISD::BUILD_PAIR)
+      return SDValue();
+    SDValue VecRed = NB->getOperand(0);
+    if (VecRed->getOpcode() != Opcode || VecRed.getResNo() != 0 ||
+        NB->getOperand(1) != SDValue(VecRed.getNode(), 1))
+      return SDValue();
+
+    SDLoc dl(N);
+    SDValue Lo = DCI.DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, NA,
+                                 DCI.DAG.getConstant(0, dl, MVT::i32));
+    SDValue Hi = DCI.DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, NA,
+                                 DCI.DAG.getConstant(1, dl, MVT::i32));
+    SDValue Red =
+        DCI.DAG.getNode(OpcodeA, dl, DCI.DAG.getVTList({MVT::i32, MVT::i32}),
+                        Lo, Hi, VecRed->getOperand(0));
+    return DCI.DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64, Red,
+                           SDValue(Red.getNode(), 1));
+  };
+
+  if (SDValue M = MakeVecReduce(ARMISD::VADDLVs, ARMISD::VADDLVAs, N0, N1))
+    return M;
+  if (SDValue M = MakeVecReduce(ARMISD::VADDLVu, ARMISD::VADDLVAu, N0, N1))
+    return M;
+  if (SDValue M = MakeVecReduce(ARMISD::VADDLVs, ARMISD::VADDLVAs, N1, N0))
+    return M;
+  if (SDValue M = MakeVecReduce(ARMISD::VADDLVu, ARMISD::VADDLVAu, N1, N0))
+    return M;
+  return SDValue();
+}
+
 bool
 ARMTargetLowering::isDesirableToCommuteWithShift(const SDNode *N,
                                                  CombineLevel Level) const {
@@ -11902,6 +11954,9 @@ static SDValue PerformADDCombine(SDNode *N,
   if (SDValue Result = PerformSHLSimplify(N, DCI, Subtarget))
     return Result;
 
+  if (SDValue Result = PerformADDVecReduce(N, DCI, Subtarget))
+    return Result;
+
   // First try with the default operand order.
   if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI, Subtarget))
     return Result;
@@ -13945,6 +14000,7 @@ static SDValue PerformVECREDUCE_ADDCombine(SDNode *N, SelectionDAG &DAG,
 
   // Cases:
   //   VADDV u/s 8/16/32
+  //   VADDLV u/s 32
   auto IsVADDV = [&](MVT RetTy, unsigned ExtendCode, ArrayRef<MVT> ExtTypes) {
     if (ResVT != RetTy || N0->getOpcode() != ExtendCode)
       return SDValue();
@@ -13954,11 +14010,19 @@ static SDValue PerformVECREDUCE_ADDCombine(SDNode *N, SelectionDAG &DAG,
       return A;
     return SDValue();
   };
+  auto Create64bitNode = [&](unsigned Opcode, ArrayRef<SDValue> Ops) {
+    SDValue Node = DAG.getNode(Opcode, dl, {MVT::i32, MVT::i32}, Ops);
+    return DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64, Node, SDValue(Node.getNode(), 1));
+  };
 
   if (SDValue A = IsVADDV(MVT::i32, ISD::SIGN_EXTEND, {MVT::v8i16, MVT::v16i8}))
     return DAG.getNode(ARMISD::VADDVs, dl, ResVT, A);
   if (SDValue A = IsVADDV(MVT::i32, ISD::ZERO_EXTEND, {MVT::v8i16, MVT::v16i8}))
     return DAG.getNode(ARMISD::VADDVu, dl, ResVT, A);
+  if (SDValue A = IsVADDV(MVT::i64, ISD::SIGN_EXTEND, {MVT::v4i32}))
+    return Create64bitNode(ARMISD::VADDLVs, {A});
+  if (SDValue A = IsVADDV(MVT::i64, ISD::ZERO_EXTEND, {MVT::v4i32}))
+    return Create64bitNode(ARMISD::VADDLVu, {A});
 
   return SDValue();
 }
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.h b/llvm/lib/Target/ARM/ARMISelLowering.h
index a12d7299ff8eb..c635622ee8d40 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.h
+++ b/llvm/lib/Target/ARM/ARMISelLowering.h
@@ -209,6 +209,10 @@ class VectorType;
       // MVE reductions
       VADDVs,
       VADDVu,
+      VADDLVs,
+      VADDLVu,
+      VADDLVAs,
+      VADDLVAu,
 
       SMULWB,       // Signed multiply word by half word, bottom
       SMULWT,       // Signed multiply word by half word, top
diff --git a/llvm/lib/Target/ARM/ARMInstrMVE.td b/llvm/lib/Target/ARM/ARMInstrMVE.td
index 2b0704fd28f2f..e4c73f824f943 100644
--- a/llvm/lib/Target/ARM/ARMInstrMVE.td
+++ b/llvm/lib/Target/ARM/ARMInstrMVE.td
@@ -691,6 +691,30 @@ multiclass MVE_VADDLV_A<string suffix, bit U, list<dag> pattern=[]> {
 defm MVE_VADDLVs32 : MVE_VADDLV_A<"s32", 0b0>;
 defm MVE_VADDLVu32 : MVE_VADDLV_A<"u32", 0b1>;
 
+def SDTVecReduceL : SDTypeProfile<2, 1, [    // VADDLV
+  SDTCisInt<0>, SDTCisInt<1>, SDTCisVec<2>
+]>;
+def SDTVecReduceLA : SDTypeProfile<2, 3, [   // VADDLVA
+  SDTCisInt<0>, SDTCisInt<1>, SDTCisInt<2>, SDTCisInt<3>,
+  SDTCisVec<4>
+]>;
+def ARMVADDLVs  : SDNode<"ARMISD::VADDLVs", SDTVecReduceL>;
+def ARMVADDLVu  : SDNode<"ARMISD::VADDLVu", SDTVecReduceL>;
+def ARMVADDLVAs : SDNode<"ARMISD::VADDLVAs", SDTVecReduceLA>;
+def ARMVADDLVAu : SDNode<"ARMISD::VADDLVAu", SDTVecReduceLA>;
+
+let Predicates = [HasMVEInt] in {
+  def : Pat<(ARMVADDLVs (v4i32 MQPR:$val1)),
+            (MVE_VADDLVs32no_acc (v4i32 MQPR:$val1))>;
+  def : Pat<(ARMVADDLVu (v4i32 MQPR:$val1)),
+            (MVE_VADDLVu32no_acc (v4i32 MQPR:$val1))>;
+
+  def : Pat<(ARMVADDLVAs tGPREven:$Rda, tGPROdd:$Rdb, (v4i32 MQPR:$val1)),
+            (MVE_VADDLVs32acc tGPREven:$Rda, tGPROdd:$Rdb, (v4i32 MQPR:$val1))>;
+  def : Pat<(ARMVADDLVAu tGPREven:$Rda, tGPROdd:$Rdb, (v4i32 MQPR:$val1)),
+            (MVE_VADDLVu32acc tGPREven:$Rda, tGPROdd:$Rdb, (v4i32 MQPR:$val1))>;
+}
+
 class MVE_VMINMAXNMV<string iname, string suffix, bit sz,
                      bit bit_17, bit bit_7, list<dag> pattern=[]>
   : MVE_rDest<(outs rGPR:$RdaDest), (ins rGPR:$RdaSrc, MQPR:$Qm),
diff --git a/llvm/test/CodeGen/Thumb2/mve-vecreduce-add.ll b/llvm/test/CodeGen/Thumb2/mve-vecreduce-add.ll
index 4ada1a65512e1..ced01f0606c7d 100644
--- a/llvm/test/CodeGen/Thumb2/mve-vecreduce-add.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-vecreduce-add.ll
@@ -14,36 +14,8 @@ entry:
 define arm_aapcs_vfpcc i64 @add_v4i32_v4i64_zext(<4 x i32> %x) {
 ; CHECK-LABEL: add_v4i32_v4i64_zext:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    adr r0, .LCPI1_0
-; CHECK-NEXT:    vmov.f32 s4, s0
-; CHECK-NEXT:    vldrw.u32 q2, [r0]
-; CHECK-NEXT:    vmov.f32 s6, s1
-; CHECK-NEXT:    vand q1, q1, q2
-; CHECK-NEXT:    vmov r2, s6
-; CHECK-NEXT:    vmov r3, s4
-; CHECK-NEXT:    vmov r0, s7
-; CHECK-NEXT:    vmov r1, s5
-; CHECK-NEXT:    vmov.f32 s4, s2
-; CHECK-NEXT:    vmov.f32 s6, s3
-; CHECK-NEXT:    vand q0, q1, q2
-; CHECK-NEXT:    adds r2, r2, r3
-; CHECK-NEXT:    vmov r3, s0
-; CHECK-NEXT:    adcs r0, r1
-; CHECK-NEXT:    vmov r1, s1
-; CHECK-NEXT:    adds r2, r2, r3
-; CHECK-NEXT:    vmov r3, s3
-; CHECK-NEXT:    adcs r1, r0
-; CHECK-NEXT:    vmov r0, s2
-; CHECK-NEXT:    adds r0, r0, r2
-; CHECK-NEXT:    adcs r1, r3
+; CHECK-NEXT:    vaddlv.u32 r0, r1, q0
 ; CHECK-NEXT:    bx lr
-; CHECK-NEXT:    .p2align 4
-; CHECK-NEXT:  @ %bb.1:
-; CHECK-NEXT:  .LCPI1_0:
-; CHECK-NEXT:    .long 4294967295 @ 0xffffffff
-; CHECK-NEXT:    .long 0 @ 0x0
-; CHECK-NEXT:    .long 4294967295 @ 0xffffffff
-; CHECK-NEXT:    .long 0 @ 0x0
 entry:
   %xx = zext <4 x i32> %x to <4 x i64>
   %z = call i64 @llvm.experimental.vector.reduce.add.v4i64(<4 x i64> %xx)
@@ -53,29 +25,7 @@ entry:
 define arm_aapcs_vfpcc i64 @add_v4i32_v4i64_sext(<4 x i32> %x) {
 ; CHECK-LABEL: add_v4i32_v4i64_sext:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmov.f32 s4, s0
-; CHECK-NEXT:    vmov.f32 s6, s1
-; CHECK-NEXT:    vmov r0, s4
-; CHECK-NEXT:    vmov.32 q2[0], r0
-; CHECK-NEXT:    asrs r0, r0, #31
-; CHECK-NEXT:    vmov.32 q2[1], r0
-; CHECK-NEXT:    vmov r0, s6
-; CHECK-NEXT:    vmov.32 q2[2], r0
-; CHECK-NEXT:    vmov.f32 s4, s2
-; CHECK-NEXT:    vmov.f32 s6, s3
-; CHECK-NEXT:    asrs r1, r0, #31
-; CHECK-NEXT:    vmov.32 q2[3], r1
-; CHECK-NEXT:    vmov r2, s10
-; CHECK-NEXT:    vmov r3, s8
-; CHECK-NEXT:    vmov r1, s9
-; CHECK-NEXT:    adds r2, r2, r3
-; CHECK-NEXT:    vmov r3, s6
-; CHECK-NEXT:    adc.w r0, r1, r0, asr #31
-; CHECK-NEXT:    vmov r1, s4
-; CHECK-NEXT:    adds r2, r2, r1
-; CHECK-NEXT:    adc.w r1, r0, r1, asr #31
-; CHECK-NEXT:    adds r0, r2, r3
-; CHECK-NEXT:    adc.w r1, r1, r3, asr #31
+; CHECK-NEXT:    vaddlv.s32 r0, r1, q0
 ; CHECK-NEXT:    bx lr
 entry:
   %xx = sext <4 x i32> %x to <4 x i64>
@@ -856,40 +806,8 @@ entry:
 define arm_aapcs_vfpcc i64 @add_v4i32_v4i64_acc_zext(<4 x i32> %x, i64 %a) {
 ; CHECK-LABEL: add_v4i32_v4i64_acc_zext:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .save {r4, lr}
-; CHECK-NEXT:    push {r4, lr}
-; CHECK-NEXT:    adr r2, .LCPI29_0
-; CHECK-NEXT:    vmov.f32 s4, s0
-; CHECK-NEXT:    vldrw.u32 q2, [r2]
-; CHECK-NEXT:    vmov.f32 s6, s1
-; CHECK-NEXT:    vand q1, q1, q2
-; CHECK-NEXT:    vmov r2, s6
-; CHECK-NEXT:    vmov r3, s4
-; CHECK-NEXT:    vmov r12, s7
-; CHECK-NEXT:    vmov lr, s5
-; CHECK-NEXT:    vmov.f32 s4, s2
-; CHECK-NEXT:    vmov.f32 s6, s3
-; CHECK-NEXT:    vand q0, q1, q2
-; CHECK-NEXT:    adds r4, r3, r2
-; CHECK-NEXT:    vmov r3, s0
-; CHECK-NEXT:    vmov r2, s1
-; CHECK-NEXT:    adc.w r12, r12, lr
-; CHECK-NEXT:    adds r3, r3, r4
-; CHECK-NEXT:    vmov r4, s3
-; CHECK-NEXT:    adc.w r12, r12, r2
-; CHECK-NEXT:    vmov r2, s2
-; CHECK-NEXT:    adds r2, r2, r3
-; CHECK-NEXT:    adc.w r3, r12, r4
-; CHECK-NEXT:    adds r0, r0, r2
-; CHECK-NEXT:    adcs r1, r3
-; CHECK-NEXT:    pop {r4, pc}
-; CHECK-NEXT:    .p2align 4
-; CHECK-NEXT:  @ %bb.1:
-; CHECK-NEXT:  .LCPI29_0:
-; CHECK-NEXT:    .long 4294967295 @ 0xffffffff
-; CHECK-NEXT:    .long 0 @ 0x0
-; CHECK-NEXT:    .long 4294967295 @ 0xffffffff
-; CHECK-NEXT:    .long 0 @ 0x0
+; CHECK-NEXT:    vaddlva.u32 r0, r1, q0
+; CHECK-NEXT:    bx lr
 entry:
   %xx = zext <4 x i32> %x to <4 x i64>
   %z = call i64 @llvm.experimental.vector.reduce.add.v4i64(<4 x i64> %xx)
@@ -900,34 +818,8 @@ entry:
 define arm_aapcs_vfpcc i64 @add_v4i32_v4i64_acc_sext(<4 x i32> %x, i64 %a) {
 ; CHECK-LABEL: add_v4i32_v4i64_acc_sext:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .save {r7, lr}
-; CHECK-NEXT:    push {r7, lr}
-; CHECK-NEXT:    vmov.f32 s4, s0
-; CHECK-NEXT:    vmov.f32 s6, s1
-; CHECK-NEXT:    vmov r2, s4
-; CHECK-NEXT:    vmov.32 q2[0], r2
-; CHECK-NEXT:    asrs r2, r2, #31
-; CHECK-NEXT:    vmov.32 q2[1], r2
-; CHECK-NEXT:    vmov r2, s6
-; CHECK-NEXT:    vmov.32 q2[2], r2
-; CHECK-NEXT:    vmov.f32 s4, s2
-; CHECK-NEXT:    vmov.f32 s6, s3
-; CHECK-NEXT:    asrs r3, r2, #31
-; CHECK-NEXT:    vmov.32 q2[3], r3
-; CHECK-NEXT:    vmov lr, s10
-; CHECK-NEXT:    vmov r3, s8
-; CHECK-NEXT:    vmov r12, s9
-; CHECK-NEXT:    adds.w r3, r3, lr
-; CHECK-NEXT:    adc.w r12, r12, r2, asr #31
-; CHECK-NEXT:    vmov r2, s4
-; CHECK-NEXT:    adds r3, r3, r2
-; CHECK-NEXT:    adc.w r12, r12, r2, asr #31
-; CHECK-NEXT:    vmov r2, s6
-; CHECK-NEXT:    adds r3, r3, r2
-; CHECK-NEXT:    adc.w r2, r12, r2, asr #31
-; CHECK-NEXT:    adds r0, r0, r3
-; CHECK-NEXT:    adcs r1, r2
-; CHECK-NEXT:    pop {r7, pc}
+; CHECK-NEXT:    vaddlva.s32 r0, r1, q0
+; CHECK-NEXT:    bx lr
 entry:
   %xx = sext <4 x i32> %x to <4 x i64>
   %z = call i64 @llvm.experimental.vector.reduce.add.v4i64(<4 x i64> %xx)
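
For reference, a minimal IR sketch (not part of the patch) of the pattern the new combines catch, modelled on the add_v4i32_v4i64_acc_zext test above. The function name is made up, and the RUN invocation is only assumed to match the existing RUN line of mve-vecreduce-add.ll, which sits outside this diff. With MVE enabled, the zero-extended v4i32-to-i64 reduction folded into an accumulating add is expected to select vaddlva.u32 rather than the removed scalar add/adc sequence, and PerformADDVecReduce accepts either operand order of the final add:

; Hypothetical reproducer; assumed RUN line, e.g.:
;   llc -mtriple=thumbv8.1m.main-none-none-eabi -mattr=+mve %s -o - | FileCheck %s
define arm_aapcs_vfpcc i64 @acc_commuted_u32(<4 x i32> %x, i64 %a) {
entry:
  %xx = zext <4 x i32> %x to <4 x i64>
  %z = call i64 @llvm.experimental.vector.reduce.add.v4i64(<4 x i64> %xx)
  ; Accumulator on the left this time; the combine tries both operand orders.
  %r = add i64 %a, %z
  ret i64 %r
}

declare i64 @llvm.experimental.vector.reduce.add.v4i64(<4 x i64>)
; Expected, per the CHECK lines above: vaddlva.u32 r0, r1, q0 followed by bx lr.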