diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index fd7356bd39fbfd..8e3f00cc77db42 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -16285,38 +16285,8 @@ static SDValue PerformVECREDUCE_ADDCombine(SDNode *N, SelectionDAG &DAG,
                        SDValue(Node.getNode(), 1));
   };
 
-  if (SDValue A = IsVADDV(MVT::i32, ISD::SIGN_EXTEND, {MVT::v8i16, MVT::v16i8}))
-    return DAG.getNode(ARMISD::VADDVs, dl, ResVT, A);
-  if (SDValue A = IsVADDV(MVT::i32, ISD::ZERO_EXTEND, {MVT::v8i16, MVT::v16i8}))
-    return DAG.getNode(ARMISD::VADDVu, dl, ResVT, A);
-  if (SDValue A = IsVADDV(MVT::i64, ISD::SIGN_EXTEND, {MVT::v4i32}))
-    return Create64bitNode(ARMISD::VADDLVs, {A});
-  if (SDValue A = IsVADDV(MVT::i64, ISD::ZERO_EXTEND, {MVT::v4i32}))
-    return Create64bitNode(ARMISD::VADDLVu, {A});
-  if (SDValue A = IsVADDV(MVT::i16, ISD::SIGN_EXTEND, {MVT::v16i8}))
-    return DAG.getNode(ISD::TRUNCATE, dl, ResVT,
-                       DAG.getNode(ARMISD::VADDVs, dl, MVT::i32, A));
-  if (SDValue A = IsVADDV(MVT::i16, ISD::ZERO_EXTEND, {MVT::v16i8}))
-    return DAG.getNode(ISD::TRUNCATE, dl, ResVT,
-                       DAG.getNode(ARMISD::VADDVu, dl, MVT::i32, A));
-
-  SDValue Mask;
-  if (SDValue A = IsPredVADDV(MVT::i32, ISD::SIGN_EXTEND, {MVT::v8i16, MVT::v16i8}, Mask))
-    return DAG.getNode(ARMISD::VADDVps, dl, ResVT, A, Mask);
-  if (SDValue A = IsPredVADDV(MVT::i32, ISD::ZERO_EXTEND, {MVT::v8i16, MVT::v16i8}, Mask))
-    return DAG.getNode(ARMISD::VADDVpu, dl, ResVT, A, Mask);
-  if (SDValue A = IsPredVADDV(MVT::i64, ISD::SIGN_EXTEND, {MVT::v4i32}, Mask))
-    return Create64bitNode(ARMISD::VADDLVps, {A, Mask});
-  if (SDValue A = IsPredVADDV(MVT::i64, ISD::ZERO_EXTEND, {MVT::v4i32}, Mask))
-    return Create64bitNode(ARMISD::VADDLVpu, {A, Mask});
-  if (SDValue A = IsPredVADDV(MVT::i16, ISD::SIGN_EXTEND, {MVT::v16i8}, Mask))
-    return DAG.getNode(ISD::TRUNCATE, dl, ResVT,
-                       DAG.getNode(ARMISD::VADDVps, dl, MVT::i32, A, Mask));
-  if (SDValue A = IsPredVADDV(MVT::i16, ISD::ZERO_EXTEND, {MVT::v16i8}, Mask))
-    return DAG.getNode(ISD::TRUNCATE, dl, ResVT,
-                       DAG.getNode(ARMISD::VADDVpu, dl, MVT::i32, A, Mask));
-
   SDValue A, B;
+  SDValue Mask;
   if (IsVMLAV(MVT::i32, ISD::SIGN_EXTEND, {MVT::v8i16, MVT::v16i8}, A, B))
     return DAG.getNode(ARMISD::VMLAVs, dl, ResVT, A, B);
   if (IsVMLAV(MVT::i32, ISD::ZERO_EXTEND, {MVT::v8i16, MVT::v16i8}, A, B))
@@ -16353,6 +16323,36 @@ static SDValue PerformVECREDUCE_ADDCombine(SDNode *N, SelectionDAG &DAG,
     return DAG.getNode(ISD::TRUNCATE, dl, ResVT,
                        DAG.getNode(ARMISD::VMLAVpu, dl, MVT::i32, A, B, Mask));
 
+  if (SDValue A = IsVADDV(MVT::i32, ISD::SIGN_EXTEND, {MVT::v8i16, MVT::v16i8}))
+    return DAG.getNode(ARMISD::VADDVs, dl, ResVT, A);
+  if (SDValue A = IsVADDV(MVT::i32, ISD::ZERO_EXTEND, {MVT::v8i16, MVT::v16i8}))
+    return DAG.getNode(ARMISD::VADDVu, dl, ResVT, A);
+  if (SDValue A = IsVADDV(MVT::i64, ISD::SIGN_EXTEND, {MVT::v4i32}))
+    return Create64bitNode(ARMISD::VADDLVs, {A});
+  if (SDValue A = IsVADDV(MVT::i64, ISD::ZERO_EXTEND, {MVT::v4i32}))
+    return Create64bitNode(ARMISD::VADDLVu, {A});
+  if (SDValue A = IsVADDV(MVT::i16, ISD::SIGN_EXTEND, {MVT::v16i8}))
+    return DAG.getNode(ISD::TRUNCATE, dl, ResVT,
+                       DAG.getNode(ARMISD::VADDVs, dl, MVT::i32, A));
+  if (SDValue A = IsVADDV(MVT::i16, ISD::ZERO_EXTEND, {MVT::v16i8}))
+    return DAG.getNode(ISD::TRUNCATE, dl, ResVT,
+                       DAG.getNode(ARMISD::VADDVu, dl, MVT::i32, A));
+
+  if (SDValue A = IsPredVADDV(MVT::i32, ISD::SIGN_EXTEND, {MVT::v8i16, MVT::v16i8}, Mask))
+    return DAG.getNode(ARMISD::VADDVps, dl, ResVT, A, Mask);
+  if (SDValue A = IsPredVADDV(MVT::i32, ISD::ZERO_EXTEND, {MVT::v8i16, MVT::v16i8}, Mask))
+    return DAG.getNode(ARMISD::VADDVpu, dl, ResVT, A, Mask);
+  if (SDValue A = IsPredVADDV(MVT::i64, ISD::SIGN_EXTEND, {MVT::v4i32}, Mask))
+    return Create64bitNode(ARMISD::VADDLVps, {A, Mask});
+  if (SDValue A = IsPredVADDV(MVT::i64, ISD::ZERO_EXTEND, {MVT::v4i32}, Mask))
+    return Create64bitNode(ARMISD::VADDLVpu, {A, Mask});
+  if (SDValue A = IsPredVADDV(MVT::i16, ISD::SIGN_EXTEND, {MVT::v16i8}, Mask))
+    return DAG.getNode(ISD::TRUNCATE, dl, ResVT,
+                       DAG.getNode(ARMISD::VADDVps, dl, MVT::i32, A, Mask));
+  if (SDValue A = IsPredVADDV(MVT::i16, ISD::ZERO_EXTEND, {MVT::v16i8}, Mask))
+    return DAG.getNode(ISD::TRUNCATE, dl, ResVT,
+                       DAG.getNode(ARMISD::VADDVpu, dl, MVT::i32, A, Mask));
+
   // Some complications. We can get a case where the two inputs of the mul are
   // the same, then the output sext will have been helpfully converted to a
   // zext. Turn it back.
diff --git a/llvm/test/CodeGen/Thumb2/mve-vecreduce-mla.ll b/llvm/test/CodeGen/Thumb2/mve-vecreduce-mla.ll
index b80a3525ab994c..0ce5474c85198d 100644
--- a/llvm/test/CodeGen/Thumb2/mve-vecreduce-mla.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-vecreduce-mla.ll
@@ -776,8 +776,7 @@ define arm_aapcs_vfpcc i64 @add_v4i8i16_v4i32_v4i64_zext(<4 x i8> %x, <4 x i16>
 ; CHECK-NEXT:    vmov.i32 q2, #0xff
 ; CHECK-NEXT:    vmovlb.u16 q1, q1
 ; CHECK-NEXT:    vand q0, q0, q2
-; CHECK-NEXT:    vmul.i32 q0, q0, q1
-; CHECK-NEXT:    vaddlv.u32 r0, r1, q0
+; CHECK-NEXT:    vmlalv.u32 r0, r1, q0, q1
 ; CHECK-NEXT:    bx lr
 entry:
   %xx = zext <4 x i8> %x to <4 x i32>
@@ -794,8 +793,7 @@ define arm_aapcs_vfpcc i64 @add_v4i8i16_v4i32_v4i64_sext(<4 x i8> %x, <4 x i16>
 ; CHECK-NEXT:    vmovlb.s8 q0, q0
 ; CHECK-NEXT:    vmovlb.s16 q1, q1
 ; CHECK-NEXT:    vmovlb.s16 q0, q0
-; CHECK-NEXT:    vmul.i32 q0, q0, q1
-; CHECK-NEXT:    vaddlv.s32 r0, r1, q0
+; CHECK-NEXT:    vmlalv.s32 r0, r1, q0, q1
 ; CHECK-NEXT:    bx lr
 entry:
   %xx = sext <4 x i8> %x to <4 x i32>
diff --git a/llvm/test/CodeGen/Thumb2/mve-vecreduce-mlapred.ll b/llvm/test/CodeGen/Thumb2/mve-vecreduce-mlapred.ll
index 8cac1710c6b6d6..0aeff64fffe8d5 100644
--- a/llvm/test/CodeGen/Thumb2/mve-vecreduce-mlapred.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-vecreduce-mlapred.ll
@@ -1530,10 +1530,9 @@ define arm_aapcs_vfpcc i64 @add_v4i8i16_v4i32_v4i64_zext(<4 x i8> %x, <4 x i16>
 ; CHECK-NEXT:    vmov.i32 q3, #0xff
 ; CHECK-NEXT:    vmovlb.u16 q1, q1
 ; CHECK-NEXT:    vand q0, q0, q3
-; CHECK-NEXT:    vmul.i32 q0, q0, q1
-; CHECK-NEXT:    vand q1, q2, q3
-; CHECK-NEXT:    vpt.i32 eq, q1, zr
-; CHECK-NEXT:    vaddlvt.u32 r0, r1, q0
+; CHECK-NEXT:    vand q2, q2, q3
+; CHECK-NEXT:    vpt.i32 eq, q2, zr
+; CHECK-NEXT:    vmlalvt.u32 r0, r1, q0, q1
 ; CHECK-NEXT:    bx lr
 entry:
   %c = icmp eq <4 x i8> %b, zeroinitializer
@@ -1550,13 +1549,12 @@ define arm_aapcs_vfpcc i64 @add_v4i8i16_v4i32_v4i64_sext(<4 x i8> %x, <4 x i16>
 ; CHECK-LABEL: add_v4i8i16_v4i32_v4i64_sext:
 ; CHECK:       @ %bb.0: @ %entry
 ; CHECK-NEXT:    vmovlb.s8 q0, q0
+; CHECK-NEXT:    vmov.i32 q3, #0xff
+; CHECK-NEXT:    vand q2, q2, q3
 ; CHECK-NEXT:    vmovlb.s16 q1, q1
 ; CHECK-NEXT:    vmovlb.s16 q0, q0
-; CHECK-NEXT:    vmul.i32 q0, q0, q1
-; CHECK-NEXT:    vmov.i32 q1, #0xff
-; CHECK-NEXT:    vand q1, q2, q1
-; CHECK-NEXT:    vpt.i32 eq, q1, zr
-; CHECK-NEXT:    vaddlvt.s32 r0, r1, q0
+; CHECK-NEXT:    vpt.i32 eq, q2, zr
+; CHECK-NEXT:    vmlalvt.s32 r0, r1, q0, q1
 ; CHECK-NEXT:    bx lr
 entry:
   %c = icmp eq <4 x i8> %b, zeroinitializer
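For reference, a minimal IR reproducer for the pattern this reordering affects, adapted from the add_v4i8i16_v4i32_v4i64_zext test above. The function name is illustrative, and the intrinsic is spelled with the current llvm.vector.reduce.add name (older trees use llvm.experimental.vector.reduce.add). With the VMLAV checks now tried before the VADDV checks, the extending reduction of the mul is selected as a single vmlalv.u32 rather than a vmul.i32 followed by vaddlv.u32:

; Compile with something like:
;   llc -mtriple=thumbv8.1m.main-none-eabi -mattr=+mve
define arm_aapcs_vfpcc i64 @reduce_mla_example(<4 x i8> %x, <4 x i16> %y) {
entry:
  ; Both operands are widened to v4i32 before the multiply...
  %xx = zext <4 x i8> %x to <4 x i32>
  %yy = zext <4 x i16> %y to <4 x i32>
  %m = mul <4 x i32> %xx, %yy
  ; ...and the product is widened again before the reduction, so the
  ; combine can match vecreduce.add(zext(mul(A, B))) as a VMLALV node
  ; instead of stopping at the plain VADDLV of the mul.
  %me = zext <4 x i32> %m to <4 x i64>
  %z = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> %me)
  ret i64 %z
}
declare i64 @llvm.vector.reduce.add.v4i64(<4 x i64>)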