diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index b6a5925123f13..94cd5f2d97cf7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2820,6 +2820,23 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) {
   return SDValue();
 }
 
+// Attempt to form avgflooru(A, B) from (A & B) + ((A ^ B) >> 1)
+static SDValue combineFixedwidthToAVGFLOORU(SDNode *N, SelectionDAG &DAG) {
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  SDValue N0 = N->getOperand(0);
+  EVT VT = N0.getValueType();
+  SDLoc DL(N);
+  if (TLI.isOperationLegal(ISD::AVGFLOORU, VT)) {
+    SDValue A, B;
+    if (sd_match(N, m_Add(m_And(m_Value(A), m_Value(B)),
+                          m_Srl(m_Xor(m_Deferred(A), m_Deferred(B)),
+                                m_SpecificInt(1))))) {
+      return DAG.getNode(ISD::AVGFLOORU, DL, VT, A, B);
+    }
+  }
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitADD(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -2835,6 +2852,10 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
   if (SDValue V = foldAddSubOfSignBit(N, DAG))
     return V;
 
+  // Try to match AVGFLOORU fixedwidth pattern
+  if (SDValue V = combineFixedwidthToAVGFLOORU(N, DAG))
+    return V;
+
   // fold (a+b) -> (a|b) iff a and b share no bits.
   if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
       DAG.haveNoCommonBitsSet(N0, N1))
diff --git a/llvm/test/CodeGen/AArch64/hadd-combine.ll b/llvm/test/CodeGen/AArch64/hadd-combine.ll
index 2269d75cdbb9e..b035ba03529cc 100644
--- a/llvm/test/CodeGen/AArch64/hadd-combine.ll
+++ b/llvm/test/CodeGen/AArch64/hadd-combine.ll
@@ -859,6 +859,18 @@ define <4 x i32> @urhadd_v4i32(<4 x i32> %x) {
   ret <4 x i32> %r
 }
 
+define <8 x i16> @uhadd_fixedwidth_v4i32(<8 x i16> %a0, <8 x i16> %a1) {
+; CHECK-LABEL: uhadd_fixedwidth_v4i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uhadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ret
+  %and = and <8 x i16> %a0, %a1
+  %xor = xor <8 x i16> %a0, %a1
+  %srl = lshr <8 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %res = add <8 x i16> %and, %srl
+  ret <8 x i16> %res
+}
+
 declare <8 x i8> @llvm.aarch64.neon.shadd.v8i8(<8 x i8>, <8 x i8>)
 declare <4 x i16> @llvm.aarch64.neon.shadd.v4i16(<4 x i16>, <4 x i16>)
 declare <2 x i32> @llvm.aarch64.neon.shadd.v2i32(<2 x i32>, <2 x i32>)
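
For reference, the new combineFixedwidthToAVGFLOORU hook relies on the carry-free averaging identity (A & B) + ((A ^ B) >> 1) == floor((A + B) / 2) for unsigned values: since A + B == 2*(A & B) + (A ^ B), adding the shared bits to half of the differing bits yields the floored average without widening past the element type. The C++ sketch below is only an illustration of that identity, not part of the patch; the helper name avg_floor_u16 and the sampling strides are arbitrary.

    // Illustration only: verify (a & b) + ((a ^ b) >> 1) == floor((a + b) / 2)
    // for unsigned 16-bit inputs, the same form the DAG combine matches.
    #include <cassert>
    #include <cstdint>

    static uint16_t avg_floor_u16(uint16_t a, uint16_t b) {
      // a + b == 2*(a & b) + (a ^ b), so the floored average is the shared
      // bits plus half of the differing bits; the result always fits in u16.
      return (uint16_t)((a & b) + ((a ^ b) >> 1));
    }

    int main() {
      for (uint32_t a = 0; a <= 0xFFFF; a += 257) {
        for (uint32_t b = 0; b <= 0xFFFF; b += 263) {
          uint32_t ref = (a + b) / 2; // reference computed in a wider type
          assert(avg_floor_u16((uint16_t)a, (uint16_t)b) == ref);
        }
      }
      return 0;
    }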