diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp index 5f95acf35c889..8e9ff53985ecf 100644 --- a/llvm/lib/Target/ARM/ARMISelLowering.cpp +++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp @@ -16704,14 +16704,16 @@ static SDValue PerformFAddVSelectCombine(SDNode *N, SelectionDAG &DAG, EVT VT = N->getValueType(0); SDLoc DL(N); - // The identity element for a fadd is -0.0, which these VMOV's represent. - auto isNegativeZeroSplat = [&](SDValue Op) { + // The identity element for a fadd is -0.0 or +0.0 when the nsz flag is set, + // which these VMOV's represent. + auto isIdentitySplat = [&](SDValue Op, bool NSZ) { if (Op.getOpcode() != ISD::BITCAST || Op.getOperand(0).getOpcode() != ARMISD::VMOVIMM) return false; - if (VT == MVT::v4f32 && Op.getOperand(0).getConstantOperandVal(0) == 1664) + uint64_t ImmVal = Op.getOperand(0).getConstantOperandVal(0); + if (VT == MVT::v4f32 && (ImmVal == 1664 || (ImmVal == 0 && NSZ))) return true; - if (VT == MVT::v8f16 && Op.getOperand(0).getConstantOperandVal(0) == 2688) + if (VT == MVT::v8f16 && (ImmVal == 2688 || (ImmVal == 0 && NSZ))) return true; return false; }; @@ -16719,12 +16721,17 @@ static SDValue PerformFAddVSelectCombine(SDNode *N, SelectionDAG &DAG, if (Op0.getOpcode() == ISD::VSELECT && Op1.getOpcode() != ISD::VSELECT) std::swap(Op0, Op1); - if (Op1.getOpcode() != ISD::VSELECT || - !isNegativeZeroSplat(Op1.getOperand(2))) + if (Op1.getOpcode() != ISD::VSELECT) return SDValue(); + + SDNodeFlags FaddFlags = N->getFlags(); + bool NSZ = FaddFlags.hasNoSignedZeros(); + if (!isIdentitySplat(Op1.getOperand(2), NSZ)) + return SDValue(); + SDValue FAdd = - DAG.getNode(ISD::FADD, DL, VT, Op0, Op1.getOperand(1), N->getFlags()); - return DAG.getNode(ISD::VSELECT, DL, VT, Op1.getOperand(0), FAdd, Op0); + DAG.getNode(ISD::FADD, DL, VT, Op0, Op1.getOperand(1), FaddFlags); + return DAG.getNode(ISD::VSELECT, DL, VT, Op1.getOperand(0), FAdd, Op0, FaddFlags); } /// PerformVDIVCombine - VCVT (fixed-point to floating-point, Advanced SIMD) diff --git a/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll b/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll index e3e23f6524ba0..0773b65b5dfe0 100644 --- a/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll +++ b/llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll @@ -363,6 +363,36 @@ entry: ret <4 x float> %b } +define arm_aapcs_vfpcc <4 x float> @fadd_v4f32_x2(<4 x float> %x, <4 x float> %y, i32 %n) { +; CHECK-LABEL: fadd_v4f32_x2: +; CHECK: @ %bb.0: @ %entry +; CHECK-NEXT: vmov.i32 q2, #0x0 +; CHECK-NEXT: vctp.32 r0 +; CHECK-NEXT: vpst +; CHECK-NEXT: vmovt q2, q1 +; CHECK-NEXT: vadd.f32 q0, q2, q0 +; CHECK-NEXT: bx lr +entry: + %c = call <4 x i1> @llvm.arm.mve.vctp32(i32 %n) + %a = select <4 x i1> %c, <4 x float> %y, <4 x float> + %b = fadd <4 x float> %a, %x + ret <4 x float> %b +} + +define arm_aapcs_vfpcc <4 x float> @fadd_v4f32_x3(<4 x float> %x, <4 x float> %y, i32 %n) { +; CHECK-LABEL: fadd_v4f32_x3: +; CHECK: @ %bb.0: @ %entry +; CHECK-NEXT: vctp.32 r0 +; CHECK-NEXT: vpst +; CHECK-NEXT: vaddt.f32 q0, q0, q1 +; CHECK-NEXT: bx lr +entry: + %c = call <4 x i1> @llvm.arm.mve.vctp32(i32 %n) + %a = select <4 x i1> %c, <4 x float> %y, <4 x float> + %b = fadd nsz <4 x float> %a, %x + ret <4 x float> %b +} + define arm_aapcs_vfpcc <8 x half> @fadd_v8f16_x(<8 x half> %x, <8 x half> %y, i32 %n) { ; CHECK-LABEL: fadd_v8f16_x: ; CHECK: @ %bb.0: @ %entry @@ -377,6 +407,36 @@ entry: ret <8 x half> %b } +define arm_aapcs_vfpcc <8 x half> @fadd_v8f16_x2(<8 x half> %x, <8 x half> %y, i32 %n) { +; CHECK-LABEL: fadd_v8f16_x2: +; CHECK: @ %bb.0: @ %entry +; CHECK-NEXT: vmov.i32 q2, #0x0 +; CHECK-NEXT: vctp.16 r0 +; CHECK-NEXT: vpst +; CHECK-NEXT: vmovt q2, q1 +; CHECK-NEXT: vadd.f16 q0, q2, q0 +; CHECK-NEXT: bx lr +entry: + %c = call <8 x i1> @llvm.arm.mve.vctp16(i32 %n) + %a = select <8 x i1> %c, <8 x half> %y, <8 x half> + %b = fadd <8 x half> %a, %x + ret <8 x half> %b +} + +define arm_aapcs_vfpcc <8 x half> @fadd_v8f16_x3(<8 x half> %x, <8 x half> %y, i32 %n) { +; CHECK-LABEL: fadd_v8f16_x3: +; CHECK: @ %bb.0: @ %entry +; CHECK-NEXT: vctp.16 r0 +; CHECK-NEXT: vpst +; CHECK-NEXT: vaddt.f16 q0, q0, q1 +; CHECK-NEXT: bx lr +entry: + %c = call <8 x i1> @llvm.arm.mve.vctp16(i32 %n) + %a = select <8 x i1> %c, <8 x half> %y, <8 x half> + %b = fadd nsz <8 x half> %a, %x + ret <8 x half> %b +} + define arm_aapcs_vfpcc <4 x float> @fsub_v4f32_x(<4 x float> %x, <4 x float> %y, i32 %n) { ; CHECK-LABEL: fsub_v4f32_x: ; CHECK: @ %bb.0: @ %entry