[ARM][NEON] Improve vector popcnt lowering with PADDL (PR39281)
As I suggested on PR39281, this patch uses PADDL pairwise addition to widen the vXi8 CTPOP result to the target vector type.
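
For illustration, a minimal sketch (not part of the commit; the function name is illustrative) of the widest case, v2i64, assuming an ARM target with NEON. The expected instruction sequence in the comments mirrors the updated popcnt.ll test below:

; llc -mtriple=arm-eabi -mattr=+neon
define <2 x i64> @ctpop_v2i64(<2 x i64> %x) {
; Expected lowering: vcnt.8 counts bits per byte, then three chained
; pairwise adds widen the byte counts up to i64 lanes:
;   vcnt.8     q8, q8
;   vpaddl.u8  q8, q8
;   vpaddl.u16 q8, q8
;   vpaddl.u32 q8, q8
  %c = call <2 x i64> @llvm.ctpop.v2i64(<2 x i64> %x)
  ret <2 x i64> %c
}
declare <2 x i64> @llvm.ctpop.v2i64(<2 x i64>)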

This is a blocker for moving more x86 code to generic vector CTPOP expansion (P32655 + D53258): ARM's vXi64 CTPOP currently expands, which would generate a vXi64 MUL, but ARM's custom lowering expands the general MUL case and vectors aren't well handled in LegalizeDAG. Improving the CTPOP lowering was a lot easier than fixing the MUL lowering for this one case.

Differential Revision: https://reviews.llvm.org/D53257

llvm-svn: 344512
RKSimon committed Oct 15, 2018
1 parent 10ec5c8 commit 5abb607
Showing 2 changed files with 43 additions and 267 deletions.
156 changes: 26 additions & 130 deletions llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -669,8 +669,8 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::CTPOP, MVT::v4i32, Custom);
setOperationAction(ISD::CTPOP, MVT::v4i16, Custom);
setOperationAction(ISD::CTPOP, MVT::v8i16, Custom);
setOperationAction(ISD::CTPOP, MVT::v1i64, Expand);
setOperationAction(ISD::CTPOP, MVT::v2i64, Expand);
setOperationAction(ISD::CTPOP, MVT::v1i64, Custom);
setOperationAction(ISD::CTPOP, MVT::v2i64, Custom);

setOperationAction(ISD::CTLZ, MVT::v1i64, Expand);
setOperationAction(ISD::CTLZ, MVT::v2i64, Expand);
@@ -5409,10 +5409,6 @@ static SDValue LowerCTTZ(SDNode *N, SelectionDAG &DAG,

// Compute with: cttz(x) = ctpop(lsb - 1)

// Since we can only compute the number of bits in a byte with vcnt.8, we
// have to gather the result with pairwise addition (vpaddl) for i16, i32,
// and i64.

// Compute LSB - 1.
SDValue Bits;
if (ElemTy == MVT::i64) {
@@ -5425,32 +5421,7 @@ static SDValue LowerCTTZ(SDNode *N, SelectionDAG &DAG,
DAG.getTargetConstant(1, dl, ElemTy));
Bits = DAG.getNode(ISD::SUB, dl, VT, LSB, One);
}

// Count #bits with vcnt.8.
EVT VT8Bit = VT.is64BitVector() ? MVT::v8i8 : MVT::v16i8;
SDValue BitsVT8 = DAG.getNode(ISD::BITCAST, dl, VT8Bit, Bits);
SDValue Cnt8 = DAG.getNode(ISD::CTPOP, dl, VT8Bit, BitsVT8);

// Gather the #bits with vpaddl (pairwise add.)
EVT VT16Bit = VT.is64BitVector() ? MVT::v4i16 : MVT::v8i16;
SDValue Cnt16 = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT16Bit,
DAG.getTargetConstant(Intrinsic::arm_neon_vpaddlu, dl, MVT::i32),
Cnt8);
if (ElemTy == MVT::i16)
return Cnt16;

EVT VT32Bit = VT.is64BitVector() ? MVT::v2i32 : MVT::v4i32;
SDValue Cnt32 = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT32Bit,
DAG.getTargetConstant(Intrinsic::arm_neon_vpaddlu, dl, MVT::i32),
Cnt16);
if (ElemTy == MVT::i32)
return Cnt32;

assert(ElemTy == MVT::i64);
SDValue Cnt64 = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
DAG.getTargetConstant(Intrinsic::arm_neon_vpaddlu, dl, MVT::i32),
Cnt32);
return Cnt64;
return DAG.getNode(ISD::CTPOP, dl, VT, Bits);
}

if (!ST->hasV6T2Ops())
@@ -5460,112 +5431,37 @@ static SDValue LowerCTTZ(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(ISD::CTLZ, dl, VT, rbit);
}

/// getCTPOP16BitCounts - Returns a v8i8/v16i8 vector containing the bit-count
/// for each 16-bit element from operand, repeated. The basic idea is to
/// leverage vcnt to get the 8-bit counts, gather and add the results.
///
/// Trace for v4i16:
/// input = [v0 v1 v2 v3 ] (vi 16-bit element)
/// cast: N0 = [w0 w1 w2 w3 w4 w5 w6 w7] (v0 = [w0 w1], wi 8-bit element)
/// vcnt: N1 = [b0 b1 b2 b3 b4 b5 b6 b7] (bi = bit-count of 8-bit element wi)
/// vrev: N2 = [b1 b0 b3 b2 b5 b4 b7 b6]
/// [b0 b1 b2 b3 b4 b5 b6 b7]
/// +[b1 b0 b3 b2 b5 b4 b7 b6]
/// N3=N1+N2 = [k0 k0 k1 k1 k2 k2 k3 k3] (k0 = b0+b1 = bit-count of 16-bit v0,
/// vuzp: = [k0 k1 k2 k3 k0 k1 k2 k3] each ki is 8-bits)
static SDValue getCTPOP16BitCounts(SDNode *N, SelectionDAG &DAG) {
EVT VT = N->getValueType(0);
SDLoc DL(N);

EVT VT8Bit = VT.is64BitVector() ? MVT::v8i8 : MVT::v16i8;
SDValue N0 = DAG.getNode(ISD::BITCAST, DL, VT8Bit, N->getOperand(0));
SDValue N1 = DAG.getNode(ISD::CTPOP, DL, VT8Bit, N0);
SDValue N2 = DAG.getNode(ARMISD::VREV16, DL, VT8Bit, N1);
SDValue N3 = DAG.getNode(ISD::ADD, DL, VT8Bit, N1, N2);
return DAG.getNode(ARMISD::VUZP, DL, VT8Bit, N3, N3);
}

/// lowerCTPOP16BitElements - Returns a v4i16/v8i16 vector containing the
/// bit-count for each 16-bit element from the operand. We need slightly
/// different sequencing for v4i16 and v8i16 to stay within NEON's available
/// 64/128-bit registers.
///
/// Trace for v4i16:
/// input = [v0 v1 v2 v3 ] (vi 16-bit element)
/// v8i8: BitCounts = [k0 k1 k2 k3 k0 k1 k2 k3 ] (ki is the bit-count of vi)
/// v8i16:Extended = [k0 k1 k2 k3 k0 k1 k2 k3 ]
/// v4i16:Extracted = [k0 k1 k2 k3 ]
static SDValue lowerCTPOP16BitElements(SDNode *N, SelectionDAG &DAG) {
static SDValue LowerCTPOP(SDNode *N, SelectionDAG &DAG,
const ARMSubtarget *ST) {
EVT VT = N->getValueType(0);
SDLoc DL(N);

SDValue BitCounts = getCTPOP16BitCounts(N, DAG);
if (VT.is64BitVector()) {
SDValue Extended = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, BitCounts);
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v4i16, Extended,
DAG.getIntPtrConstant(0, DL));
} else {
SDValue Extracted = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8,
BitCounts, DAG.getIntPtrConstant(0, DL));
return DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, Extracted);
}
}

/// lowerCTPOP32BitElements - Returns a v2i32/v4i32 vector containing the
/// bit-count for each 32-bit element from the operand. The idea here is
/// to split the vector into 16-bit elements, leverage the 16-bit count
/// routine, and then combine the results.
///
/// Trace for v2i32 (v4i32 similar with Extracted/Extended exchanged):
/// input = [v0 v1 ] (vi: 32-bit elements)
/// Bitcast = [w0 w1 w2 w3 ] (wi: 16-bit elements, v0 = [w0 w1])
/// Counts16 = [k0 k1 k2 k3 ] (ki: 16-bit elements, bit-count of wi)
/// vrev: N0 = [k1 k0 k3 k2 ]
/// [k0 k1 k2 k3 ]
/// N1 =+[k1 k0 k3 k2 ]
/// [k0 k2 k1 k3 ]
/// N2 =+[k1 k3 k0 k2 ]
/// [k0 k2 k1 k3 ]
/// Extended =+[k1 k3 k0 k2 ]
/// [k0 k2 ]
/// Extracted=+[k1 k3 ]
///
static SDValue lowerCTPOP32BitElements(SDNode *N, SelectionDAG &DAG) {
EVT VT = N->getValueType(0);
SDLoc DL(N);
assert(ST->hasNEON() && "Custom ctpop lowering requires NEON.");
assert((VT == MVT::v1i64 || VT == MVT::v2i64 || VT == MVT::v2i32 ||
VT == MVT::v4i32 || VT == MVT::v4i16 || VT == MVT::v8i16) &&
"Unexpected type for custom ctpop lowering");

EVT VT16Bit = VT.is64BitVector() ? MVT::v4i16 : MVT::v8i16;
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
EVT VT8Bit = VT.is64BitVector() ? MVT::v8i8 : MVT::v16i8;
SDValue Res = DAG.getBitcast(VT8Bit, N->getOperand(0));
Res = DAG.getNode(ISD::CTPOP, DL, VT8Bit, Res);

SDValue Bitcast = DAG.getNode(ISD::BITCAST, DL, VT16Bit, N->getOperand(0));
SDValue Counts16 = lowerCTPOP16BitElements(Bitcast.getNode(), DAG);
SDValue N0 = DAG.getNode(ARMISD::VREV32, DL, VT16Bit, Counts16);
SDValue N1 = DAG.getNode(ISD::ADD, DL, VT16Bit, Counts16, N0);
SDValue N2 = DAG.getNode(ARMISD::VUZP, DL, VT16Bit, N1, N1);
// Widen v8i8/v16i8 CTPOP result to VT by repeatedly widening pairwise adds.
unsigned EltSize = 8;
unsigned NumElts = VT.is64BitVector() ? 8 : 16;
while (EltSize != VT.getScalarSizeInBits()) {
SmallVector<SDValue, 8> Ops;
Ops.push_back(DAG.getConstant(Intrinsic::arm_neon_vpaddlu, DL,
TLI.getPointerTy(DAG.getDataLayout())));
Ops.push_back(Res);

if (VT.is64BitVector()) {
SDValue Extended = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v4i32, N2);
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v2i32, Extended,
DAG.getIntPtrConstant(0, DL));
} else {
SDValue Extracted = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v4i16, N2,
DAG.getIntPtrConstant(0, DL));
return DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v4i32, Extracted);
EltSize *= 2;
NumElts /= 2;
MVT WidenVT = MVT::getVectorVT(MVT::getIntegerVT(EltSize), NumElts);
Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WidenVT, Ops);
}
}

static SDValue LowerCTPOP(SDNode *N, SelectionDAG &DAG,
const ARMSubtarget *ST) {
EVT VT = N->getValueType(0);

assert(ST->hasNEON() && "Custom ctpop lowering requires NEON.");
assert((VT == MVT::v2i32 || VT == MVT::v4i32 ||
VT == MVT::v4i16 || VT == MVT::v8i16) &&
"Unexpected type for custom ctpop lowering");

if (VT.getVectorElementType() == MVT::i32)
return lowerCTPOP32BitElements(N, DAG);
else
return lowerCTPOP16BitElements(N, DAG);
return Res;
}

static SDValue LowerShift(SDNode *N, SelectionDAG &DAG,
154 changes: 17 additions & 137 deletions llvm/test/CodeGen/ARM/popcnt.ll
@@ -32,11 +32,7 @@ define <4 x i16> @vcnt16(<4 x i16>* %A) nounwind {
; CHECK: @ %bb.0:
; CHECK-NEXT: vldr d16, [r0]
; CHECK-NEXT: vcnt.8 d16, d16
; CHECK-NEXT: vrev16.8 d17, d16
; CHECK-NEXT: vadd.i8 d16, d16, d17
; CHECK-NEXT: vorr d17, d16, d16
; CHECK-NEXT: vuzp.8 d16, d17
; CHECK-NEXT: vmovl.u8 q8, d16
; CHECK-NEXT: vpaddl.u8 d16, d16
; CHECK-NEXT: vmov r0, r1, d16
; CHECK-NEXT: mov pc, lr
%tmp1 = load <4 x i16>, <4 x i16>* %A
@@ -49,11 +45,7 @@ define <8 x i16> @vcntQ16(<8 x i16>* %A) nounwind {
; CHECK: @ %bb.0:
; CHECK-NEXT: vld1.64 {d16, d17}, [r0]
; CHECK-NEXT: vcnt.8 q8, q8
; CHECK-NEXT: vrev16.8 q9, q8
; CHECK-NEXT: vadd.i8 q8, q8, q9
; CHECK-NEXT: vorr q9, q8, q8
; CHECK-NEXT: vuzp.8 q8, q9
; CHECK-NEXT: vmovl.u8 q8, d16
; CHECK-NEXT: vpaddl.u8 q8, q8
; CHECK-NEXT: vmov r0, r1, d16
; CHECK-NEXT: vmov r2, r3, d17
; CHECK-NEXT: mov pc, lr
@@ -67,16 +59,8 @@ define <2 x i32> @vcnt32(<2 x i32>* %A) nounwind {
; CHECK: @ %bb.0:
; CHECK-NEXT: vldr d16, [r0]
; CHECK-NEXT: vcnt.8 d16, d16
; CHECK-NEXT: vrev16.8 d17, d16
; CHECK-NEXT: vadd.i8 d16, d16, d17
; CHECK-NEXT: vorr d17, d16, d16
; CHECK-NEXT: vuzp.8 d16, d17
; CHECK-NEXT: vmovl.u8 q8, d16
; CHECK-NEXT: vrev32.16 d18, d16
; CHECK-NEXT: vadd.i16 d16, d16, d18
; CHECK-NEXT: vorr d17, d16, d16
; CHECK-NEXT: vuzp.16 d16, d17
; CHECK-NEXT: vmovl.u16 q8, d16
; CHECK-NEXT: vpaddl.u8 d16, d16
; CHECK-NEXT: vpaddl.u16 d16, d16
; CHECK-NEXT: vmov r0, r1, d16
; CHECK-NEXT: mov pc, lr
%tmp1 = load <2 x i32>, <2 x i32>* %A
@@ -89,16 +73,8 @@ define <4 x i32> @vcntQ32(<4 x i32>* %A) nounwind {
; CHECK: @ %bb.0:
; CHECK-NEXT: vld1.64 {d16, d17}, [r0]
; CHECK-NEXT: vcnt.8 q8, q8
; CHECK-NEXT: vrev16.8 q9, q8
; CHECK-NEXT: vadd.i8 q8, q8, q9
; CHECK-NEXT: vorr q9, q8, q8
; CHECK-NEXT: vuzp.8 q8, q9
; CHECK-NEXT: vmovl.u8 q9, d16
; CHECK-NEXT: vrev32.16 q9, q9
; CHECK-NEXT: vaddw.u8 q8, q9, d16
; CHECK-NEXT: vorr q9, q8, q8
; CHECK-NEXT: vuzp.16 q8, q9
; CHECK-NEXT: vmovl.u16 q8, d16
; CHECK-NEXT: vpaddl.u8 q8, q8
; CHECK-NEXT: vpaddl.u16 q8, q8
; CHECK-NEXT: vmov r0, r1, d16
; CHECK-NEXT: vmov r2, r3, d17
; CHECK-NEXT: mov pc, lr
@@ -110,50 +86,13 @@ define <4 x i32> @vcntQ32(<4 x i32>* %A) nounwind {
define <1 x i64> @vcnt64(<1 x i64>* %A) nounwind {
; CHECK-LABEL: vcnt64:
; CHECK: @ %bb.0:
; CHECK-NEXT: .save {r4, lr}
; CHECK-NEXT: push {r4, lr}
; CHECK-NEXT: vldr d16, [r0]
; CHECK-NEXT: ldr r2, .LCPI6_0
; CHECK-NEXT: vmov.32 r0, d16[0]
; CHECK-NEXT: ldr r3, .LCPI6_3
; CHECK-NEXT: vmov.32 r1, d16[1]
; CHECK-NEXT: ldr lr, .LCPI6_2
; CHECK-NEXT: ldr r12, .LCPI6_1
; CHECK-NEXT: vldr s1, .LCPI6_4
; CHECK-NEXT: and r4, r2, r0, lsr #1
; CHECK-NEXT: sub r0, r0, r4
; CHECK-NEXT: and r2, r2, r1, lsr #1
; CHECK-NEXT: sub r1, r1, r2
; CHECK-NEXT: and r4, r0, r3
; CHECK-NEXT: and r0, r3, r0, lsr #2
; CHECK-NEXT: and r2, r1, r3
; CHECK-NEXT: add r0, r4, r0
; CHECK-NEXT: and r1, r3, r1, lsr #2
; CHECK-NEXT: add r1, r2, r1
; CHECK-NEXT: add r0, r0, r0, lsr #4
; CHECK-NEXT: and r0, r0, lr
; CHECK-NEXT: add r1, r1, r1, lsr #4
; CHECK-NEXT: mul r2, r0, r12
; CHECK-NEXT: and r0, r1, lr
; CHECK-NEXT: mul r1, r0, r12
; CHECK-NEXT: lsr r0, r2, #24
; CHECK-NEXT: add r0, r0, r1, lsr #24
; CHECK-NEXT: vmov s0, r0
; CHECK-NEXT: vmov r0, r1, d0
; CHECK-NEXT: pop {r4, lr}
; CHECK-NEXT: vcnt.8 d16, d16
; CHECK-NEXT: vpaddl.u8 d16, d16
; CHECK-NEXT: vpaddl.u16 d16, d16
; CHECK-NEXT: vpaddl.u32 d16, d16
; CHECK-NEXT: vmov r0, r1, d16
; CHECK-NEXT: mov pc, lr
; CHECK-NEXT: .p2align 2
; CHECK-NEXT: @ %bb.1:
; CHECK-NEXT: .LCPI6_0:
; CHECK-NEXT: .long 1431655765 @ 0x55555555
; CHECK-NEXT: .LCPI6_1:
; CHECK-NEXT: .long 16843009 @ 0x1010101
; CHECK-NEXT: .LCPI6_2:
; CHECK-NEXT: .long 252645135 @ 0xf0f0f0f
; CHECK-NEXT: .LCPI6_3:
; CHECK-NEXT: .long 858993459 @ 0x33333333
; CHECK-NEXT: .LCPI6_4:
; CHECK-NEXT: .long 0 @ float 0
%tmp1 = load <1 x i64>, <1 x i64>* %A
%tmp2 = call <1 x i64> @llvm.ctpop.v1i64(<1 x i64> %tmp1)
ret <1 x i64> %tmp2
@@ -162,73 +101,14 @@ define <1 x i64> @vcnt64(<1 x i64>* %A) nounwind {
define <2 x i64> @vcntQ64(<2 x i64>* %A) nounwind {
; CHECK-LABEL: vcntQ64:
; CHECK: @ %bb.0:
; CHECK-NEXT: .save {r4, r5, r6, lr}
; CHECK-NEXT: push {r4, r5, r6, lr}
; CHECK-NEXT: vld1.64 {d16, d17}, [r0]
; CHECK-NEXT: vmov.32 r1, d17[1]
; CHECK-NEXT: ldr lr, .LCPI7_0
; CHECK-NEXT: vmov.32 r2, d17[0]
; CHECK-NEXT: ldr r0, .LCPI7_2
; CHECK-NEXT: vmov.32 r3, d16[0]
; CHECK-NEXT: ldr r12, .LCPI7_1
; CHECK-NEXT: ldr r5, .LCPI7_3
; CHECK-NEXT: vldr s3, .LCPI7_4
; CHECK-NEXT: and r4, lr, r1, lsr #1
; CHECK-NEXT: sub r1, r1, r4
; CHECK-NEXT: and r4, r1, r0
; CHECK-NEXT: and r1, r0, r1, lsr #2
; CHECK-NEXT: add r1, r4, r1
; CHECK-NEXT: and r4, lr, r2, lsr #1
; CHECK-NEXT: sub r2, r2, r4
; CHECK-NEXT: and r4, r2, r0
; CHECK-NEXT: add r1, r1, r1, lsr #4
; CHECK-NEXT: and r2, r0, r2, lsr #2
; CHECK-NEXT: and r6, r1, r12
; CHECK-NEXT: add r2, r4, r2
; CHECK-NEXT: and r4, lr, r3, lsr #1
; CHECK-NEXT: sub r3, r3, r4
; CHECK-NEXT: and r4, r3, r0
; CHECK-NEXT: add r2, r2, r2, lsr #4
; CHECK-NEXT: and r3, r0, r3, lsr #2
; CHECK-NEXT: and r2, r2, r12
; CHECK-NEXT: add r3, r4, r3
; CHECK-NEXT: add r3, r3, r3, lsr #4
; CHECK-NEXT: and r3, r3, r12
; CHECK-NEXT: mul r4, r3, r5
; CHECK-NEXT: vmov.32 r3, d16[1]
; CHECK-NEXT: and r1, lr, r3, lsr #1
; CHECK-NEXT: sub r1, r3, r1
; CHECK-NEXT: and r3, r1, r0
; CHECK-NEXT: and r0, r0, r1, lsr #2
; CHECK-NEXT: mul r1, r2, r5
; CHECK-NEXT: add r0, r3, r0
; CHECK-NEXT: mul r2, r6, r5
; CHECK-NEXT: add r0, r0, r0, lsr #4
; CHECK-NEXT: and r0, r0, r12
; CHECK-NEXT: mul r3, r0, r5
; CHECK-NEXT: lsr r0, r1, #24
; CHECK-NEXT: lsr r1, r4, #24
; CHECK-NEXT: add r0, r0, r2, lsr #24
; CHECK-NEXT: vmov s2, r0
; CHECK-NEXT: add r0, r1, r3, lsr #24
; CHECK-NEXT: vmov s0, r0
; CHECK-NEXT: vmov.f32 s1, s3
; CHECK-NEXT: vmov r2, r3, d1
; CHECK-NEXT: vmov r0, r1, d0
; CHECK-NEXT: pop {r4, r5, r6, lr}
; CHECK-NEXT: vcnt.8 q8, q8
; CHECK-NEXT: vpaddl.u8 q8, q8
; CHECK-NEXT: vpaddl.u16 q8, q8
; CHECK-NEXT: vpaddl.u32 q8, q8
; CHECK-NEXT: vmov r0, r1, d16
; CHECK-NEXT: vmov r2, r3, d17
; CHECK-NEXT: mov pc, lr
; CHECK-NEXT: .p2align 2
; CHECK-NEXT: @ %bb.1:
; CHECK-NEXT: .LCPI7_0:
; CHECK-NEXT: .long 1431655765 @ 0x55555555
; CHECK-NEXT: .LCPI7_1:
; CHECK-NEXT: .long 252645135 @ 0xf0f0f0f
; CHECK-NEXT: .LCPI7_2:
; CHECK-NEXT: .long 858993459 @ 0x33333333
; CHECK-NEXT: .LCPI7_3:
; CHECK-NEXT: .long 16843009 @ 0x1010101
; CHECK-NEXT: .LCPI7_4:
; CHECK-NEXT: .long 0 @ float 0
%tmp1 = load <2 x i64>, <2 x i64>* %A
%tmp2 = call <2 x i64> @llvm.ctpop.v2i64(<2 x i64> %tmp1)
ret <2 x i64> %tmp2
