diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 60aa61e993b26..b6342220a3159 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -50,6 +50,7 @@ #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineMemOperand.h" #include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/SDPatternMatch.h" #include "llvm/CodeGen/SelectionDAG.h" #include "llvm/CodeGen/SelectionDAGNodes.h" #include "llvm/CodeGen/TargetCallingConv.h" @@ -104,7 +105,6 @@ #include using namespace llvm; -using namespace llvm::PatternMatch; #define DEBUG_TYPE "aarch64-lower" @@ -1180,6 +1180,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setTargetDAGCombine(ISD::SHL); setTargetDAGCombine(ISD::VECTOR_DEINTERLEAVE); + setTargetDAGCombine(ISD::CTPOP); // In case of strict alignment, avoid an excessive number of byte wide stores. MaxStoresPerMemsetOptSize = 8; @@ -17591,6 +17592,7 @@ bool AArch64TargetLowering::optimizeExtendOrTruncateConversion( // udot instruction. if (SrcWidth * 4 <= DstWidth) { if (all_of(I->users(), [&](auto *U) { + using namespace llvm::PatternMatch; auto *SingleUser = cast(&*U); if (match(SingleUser, m_c_Mul(m_Specific(I), m_SExt(m_Value())))) return true; @@ -17862,6 +17864,7 @@ bool AArch64TargetLowering::lowerInterleavedLoad( // into shift / and masks. For the moment we do this just for uitofp (not // zext) to avoid issues with widening instructions. if (Shuffles.size() == 4 && all_of(Shuffles, [](ShuffleVectorInst *SI) { + using namespace llvm::PatternMatch; return SI->hasOneUse() && match(SI->user_back(), m_UIToFP(m_Value())) && SI->getType()->getScalarSizeInBits() * 4 == SI->user_back()->getType()->getScalarSizeInBits(); @@ -27878,6 +27881,35 @@ static SDValue performRNDRCombine(SDNode *N, SelectionDAG &DAG) { {A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL); } +static SDValue performCTPOPCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + SelectionDAG &DAG) { + using namespace llvm::SDPatternMatch; + if (!DCI.isBeforeLegalize()) + return SDValue(); + + // ctpop(zext(bitcast(vector_mask))) -> neg(signed_reduce_add(vector_mask)) + SDValue Mask; + if (!sd_match(N->getOperand(0), m_ZExt(m_BitCast(m_Value(Mask))))) + return SDValue(); + + EVT VT = N->getValueType(0); + EVT MaskVT = Mask.getValueType(); + + if (VT.isVector() || !MaskVT.isFixedLengthVector() || + MaskVT.getVectorElementType() != MVT::i1) + return SDValue(); + + EVT ReduceInVT = + EVT::getVectorVT(*DAG.getContext(), VT, MaskVT.getVectorElementCount()); + + SDLoc DL(N); + // Sign extend to best fit ZeroOrNegativeOneBooleanContent. + SDValue ExtMask = DAG.getNode(ISD::SIGN_EXTEND, DL, ReduceInVT, Mask); + SDValue NegPopCount = DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, ExtMask); + return DAG.getNegative(NegPopCount, DL, VT); +} + SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const { SelectionDAG &DAG = DCI.DAG; @@ -28223,6 +28255,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, return performScalarToVectorCombine(N, DCI, DAG); case ISD::SHL: return performSHLCombine(N, DCI, DAG); + case ISD::CTPOP: + return performCTPOPCombine(N, DCI, DAG); } return SDValue(); } diff --git a/llvm/test/CodeGen/AArch64/popcount_vmask.ll b/llvm/test/CodeGen/AArch64/popcount_vmask.ll new file mode 100644 index 0000000000000..e784ead2c9e5a --- /dev/null +++ b/llvm/test/CodeGen/AArch64/popcount_vmask.ll @@ -0,0 +1,315 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 +; RUN: llc < %s | FileCheck %s + +target triple = "aarch64-unknown-linux-gnu" + +define i32 @vmask_popcount_i32_v8i8(<8 x i8> %a, <8 x i8> %b) { +; CHECK-LABEL: vmask_popcount_i32_v8i8: +; CHECK: // %bb.0: +; CHECK-NEXT: cmgt v0.8b, v1.8b, v0.8b +; CHECK-NEXT: sshll v0.8h, v0.8b, #0 +; CHECK-NEXT: saddlv s0, v0.8h +; CHECK-NEXT: fmov w8, s0 +; CHECK-NEXT: neg w0, w8 +; CHECK-NEXT: ret + %mask = icmp slt <8 x i8> %a, %b + %t1 = bitcast <8 x i1> %mask to i8 + %t2 = call i8 @llvm.ctpop(i8 %t1) + %t3 = zext i8 %t2 to i32 + ret i32 %t3 +} + +define i32 @vmask_popcount_i32_v16i8(<16 x i8> %a, <16 x i8> %b) { +; CHECK-LABEL: vmask_popcount_i32_v16i8: +; CHECK: // %bb.0: +; CHECK-NEXT: cmgt v0.16b, v1.16b, v0.16b +; CHECK-NEXT: sshll2 v1.8h, v0.16b, #0 +; CHECK-NEXT: sshll v0.8h, v0.8b, #0 +; CHECK-NEXT: saddl2 v2.4s, v0.8h, v1.8h +; CHECK-NEXT: saddl v0.4s, v0.4h, v1.4h +; CHECK-NEXT: add v0.4s, v0.4s, v2.4s +; CHECK-NEXT: addv s0, v0.4s +; CHECK-NEXT: fmov w8, s0 +; CHECK-NEXT: neg w0, w8 +; CHECK-NEXT: ret + %mask = icmp slt <16 x i8> %a, %b + %t1 = bitcast <16 x i1> %mask to i16 + %t2 = call i16 @llvm.ctpop(i16 %t1) + %t3 = zext i16 %t2 to i32 + ret i32 %t3 +} + +define i32 @vmask_popcount_i32_v4i16(<4 x i16> %a, <4 x i16> %b) { +; CHECK-LABEL: vmask_popcount_i32_v4i16: +; CHECK: // %bb.0: +; CHECK-NEXT: cmgt v0.4h, v1.4h, v0.4h +; CHECK-NEXT: saddlv s0, v0.4h +; CHECK-NEXT: fmov w8, s0 +; CHECK-NEXT: neg w0, w8 +; CHECK-NEXT: ret + %mask = icmp slt <4 x i16> %a, %b + %t1 = bitcast <4 x i1> %mask to i4 + %t2 = call i4 @llvm.ctpop(i4 %t1) + %t3 = zext i4 %t2 to i32 + ret i32 %t3 +} + +define i32 @vmask_popcount_i32_v8i16(<8 x i16> %a, <8 x i16> %b) { +; CHECK-LABEL: vmask_popcount_i32_v8i16: +; CHECK: // %bb.0: +; CHECK-NEXT: cmgt v0.8h, v1.8h, v0.8h +; CHECK-NEXT: saddlv s0, v0.8h +; CHECK-NEXT: fmov w8, s0 +; CHECK-NEXT: neg w0, w8 +; CHECK-NEXT: ret + %mask = icmp slt <8 x i16> %a, %b + %t1 = bitcast <8 x i1> %mask to i8 + %t2 = call i8 @llvm.ctpop(i8 %t1) + %t3 = zext i8 %t2 to i32 + ret i32 %t3 +} + +define i32 @vmask_popcount_i32_v2i32(<2 x i32> %a, <2 x i32> %b) { +; CHECK-LABEL: vmask_popcount_i32_v2i32: +; CHECK: // %bb.0: +; CHECK-NEXT: cmgt v0.2s, v1.2s, v0.2s +; CHECK-NEXT: addp v0.2s, v0.2s, v0.2s +; CHECK-NEXT: fmov w8, s0 +; CHECK-NEXT: neg w0, w8 +; CHECK-NEXT: ret + %mask = icmp slt <2 x i32> %a, %b + %t1 = bitcast <2 x i1> %mask to i2 + %t2 = call i2 @llvm.ctpop(i2 %t1) + %t3 = zext i2 %t2 to i32 + ret i32 %t3 +} + +define i32 @vmask_popcount_i32_v4i32(<4 x i32> %a, <4 x i32> %b) { +; CHECK-LABEL: vmask_popcount_i32_v4i32: +; CHECK: // %bb.0: +; CHECK-NEXT: cmgt v0.4s, v1.4s, v0.4s +; CHECK-NEXT: addv s0, v0.4s +; CHECK-NEXT: fmov w8, s0 +; CHECK-NEXT: neg w0, w8 +; CHECK-NEXT: ret + %mask = icmp slt <4 x i32> %a, %b + %t1 = bitcast <4 x i1> %mask to i4 + %t2 = call i4 @llvm.ctpop(i4 %t1) + %t3 = zext i4 %t2 to i32 + ret i32 %t3 +} + +define i32 @vmask_popcount_i32_v1i64(<1 x i64> %a, <1 x i64> %b) { +; CHECK-LABEL: vmask_popcount_i32_v1i64: +; CHECK: // %bb.0: +; CHECK-NEXT: // kill: def $d1 killed $d1 def $q1 +; CHECK-NEXT: // kill: def $d0 killed $d0 def $q0 +; CHECK-NEXT: fmov x8, d1 +; CHECK-NEXT: fmov x9, d0 +; CHECK-NEXT: cmp x9, x8 +; CHECK-NEXT: cset w0, lt +; CHECK-NEXT: ret + %mask = icmp slt <1 x i64> %a, %b + %t1 = bitcast <1 x i1> %mask to i1 + %t2 = call i1 @llvm.ctpop(i1 %t1) + %t3 = zext i1 %t2 to i32 + ret i32 %t3 +} + +define i32 @vmask_popcount_i32_v2i64(<2 x i64> %a, <2 x i64> %b) { +; CHECK-LABEL: vmask_popcount_i32_v2i64: +; CHECK: // %bb.0: +; CHECK-NEXT: cmgt v0.2d, v1.2d, v0.2d +; CHECK-NEXT: xtn v0.2s, v0.2d +; CHECK-NEXT: addp v0.2s, v0.2s, v0.2s +; CHECK-NEXT: fmov w8, s0 +; CHECK-NEXT: neg w0, w8 +; CHECK-NEXT: ret + %mask = icmp slt <2 x i64> %a, %b + %t1 = bitcast <2 x i1> %mask to i2 + %t2 = call i2 @llvm.ctpop(i2 %t1) + %t3 = zext i2 %t2 to i32 + ret i32 %t3 +} + +define i64 @vmask_popcount_i64_v8i8(<8 x i8> %a, <8 x i8> %b) { +; CHECK-LABEL: vmask_popcount_i64_v8i8: +; CHECK: // %bb.0: +; CHECK-NEXT: cmgt v0.8b, v1.8b, v0.8b +; CHECK-NEXT: sshll v0.8h, v0.8b, #0 +; CHECK-NEXT: saddlv s0, v0.8h +; CHECK-NEXT: fmov w8, s0 +; CHECK-NEXT: neg w0, w8 +; CHECK-NEXT: ret + %mask = icmp slt <8 x i8> %a, %b + %t1 = bitcast <8 x i1> %mask to i8 + %t2 = call i8 @llvm.ctpop(i8 %t1) + %t3 = zext i8 %t2 to i64 + ret i64 %t3 +} + +define i64 @vmask_popcount_i64_v16i8(<16 x i8> %a, <16 x i8> %b) { +; CHECK-LABEL: vmask_popcount_i64_v16i8: +; CHECK: // %bb.0: +; CHECK-NEXT: cmgt v0.16b, v1.16b, v0.16b +; CHECK-NEXT: sshll2 v1.8h, v0.16b, #0 +; CHECK-NEXT: sshll v0.8h, v0.8b, #0 +; CHECK-NEXT: saddl2 v2.4s, v0.8h, v1.8h +; CHECK-NEXT: saddl v0.4s, v0.4h, v1.4h +; CHECK-NEXT: add v0.4s, v0.4s, v2.4s +; CHECK-NEXT: addv s0, v0.4s +; CHECK-NEXT: fmov w8, s0 +; CHECK-NEXT: neg w0, w8 +; CHECK-NEXT: ret + %mask = icmp slt <16 x i8> %a, %b + %t1 = bitcast <16 x i1> %mask to i16 + %t2 = call i16 @llvm.ctpop(i16 %t1) + %t3 = zext i16 %t2 to i64 + ret i64 %t3 +} + +define i64 @vmask_popcount_i64_v4i16(<4 x i16> %a, <4 x i16> %b) { +; CHECK-LABEL: vmask_popcount_i64_v4i16: +; CHECK: // %bb.0: +; CHECK-NEXT: cmgt v0.4h, v1.4h, v0.4h +; CHECK-NEXT: saddlv s0, v0.4h +; CHECK-NEXT: fmov w8, s0 +; CHECK-NEXT: neg w0, w8 +; CHECK-NEXT: ret + %mask = icmp slt <4 x i16> %a, %b + %t1 = bitcast <4 x i1> %mask to i4 + %t2 = call i4 @llvm.ctpop(i4 %t1) + %t3 = zext i4 %t2 to i64 + ret i64 %t3 +} + +define i64 @vmask_popcount_i64_v8i16(<8 x i16> %a, <8 x i16> %b) { +; CHECK-LABEL: vmask_popcount_i64_v8i16: +; CHECK: // %bb.0: +; CHECK-NEXT: cmgt v0.8h, v1.8h, v0.8h +; CHECK-NEXT: saddlv s0, v0.8h +; CHECK-NEXT: fmov w8, s0 +; CHECK-NEXT: neg w0, w8 +; CHECK-NEXT: ret + %mask = icmp slt <8 x i16> %a, %b + %t1 = bitcast <8 x i1> %mask to i8 + %t2 = call i8 @llvm.ctpop(i8 %t1) + %t3 = zext i8 %t2 to i64 + ret i64 %t3 +} + +define i64 @vmask_popcount_i64_v2i32(<2 x i32> %a, <2 x i32> %b) { +; CHECK-LABEL: vmask_popcount_i64_v2i32: +; CHECK: // %bb.0: +; CHECK-NEXT: cmgt v0.2s, v1.2s, v0.2s +; CHECK-NEXT: addp v0.2s, v0.2s, v0.2s +; CHECK-NEXT: fmov w8, s0 +; CHECK-NEXT: neg w0, w8 +; CHECK-NEXT: ret + %mask = icmp slt <2 x i32> %a, %b + %t1 = bitcast <2 x i1> %mask to i2 + %t2 = call i2 @llvm.ctpop(i2 %t1) + %t3 = zext i2 %t2 to i64 + ret i64 %t3 +} + +define i64 @vmask_popcount_i64_v4i32(<4 x i32> %a, <4 x i32> %b) { +; CHECK-LABEL: vmask_popcount_i64_v4i32: +; CHECK: // %bb.0: +; CHECK-NEXT: cmgt v0.4s, v1.4s, v0.4s +; CHECK-NEXT: addv s0, v0.4s +; CHECK-NEXT: fmov w8, s0 +; CHECK-NEXT: neg w0, w8 +; CHECK-NEXT: ret + %mask = icmp slt <4 x i32> %a, %b + %t1 = bitcast <4 x i1> %mask to i4 + %t2 = call i4 @llvm.ctpop(i4 %t1) + %t3 = zext i4 %t2 to i64 + ret i64 %t3 +} + +define i64 @vmask_popcount_i64_v1i64(<1 x i64> %a, <1 x i64> %b) { +; CHECK-LABEL: vmask_popcount_i64_v1i64: +; CHECK: // %bb.0: +; CHECK-NEXT: // kill: def $d1 killed $d1 def $q1 +; CHECK-NEXT: // kill: def $d0 killed $d0 def $q0 +; CHECK-NEXT: fmov x8, d1 +; CHECK-NEXT: fmov x9, d0 +; CHECK-NEXT: cmp x9, x8 +; CHECK-NEXT: cset w0, lt +; CHECK-NEXT: ret + %mask = icmp slt <1 x i64> %a, %b + %t1 = bitcast <1 x i1> %mask to i1 + %t2 = call i1 @llvm.ctpop(i1 %t1) + %t3 = zext i1 %t2 to i64 + ret i64 %t3 +} + +define i64 @vmask_popcount_i64_v2i64(<2 x i64> %a, <2 x i64> %b) { +; CHECK-LABEL: vmask_popcount_i64_v2i64: +; CHECK: // %bb.0: +; CHECK-NEXT: cmgt v0.2d, v1.2d, v0.2d +; CHECK-NEXT: xtn v0.2s, v0.2d +; CHECK-NEXT: addp v0.2s, v0.2s, v0.2s +; CHECK-NEXT: fmov w8, s0 +; CHECK-NEXT: neg w0, w8 +; CHECK-NEXT: ret + %mask = icmp slt <2 x i64> %a, %b + %t1 = bitcast <2 x i1> %mask to i2 + %t2 = call i2 @llvm.ctpop(i2 %t1) + %t3 = zext i2 %t2 to i64 + ret i64 %t3 +} + +define i32 @non_vmask_popcount_1(half %a) { +; CHECK-LABEL: non_vmask_popcount_1: +; CHECK: // %bb.0: +; CHECK-NEXT: // kill: def $h0 killed $h0 def $s0 +; CHECK-NEXT: fmov w8, s0 +; CHECK-NEXT: and w8, w8, #0xffff +; CHECK-NEXT: fmov s0, w8 +; CHECK-NEXT: cnt v0.8b, v0.8b +; CHECK-NEXT: addv b0, v0.8b +; CHECK-NEXT: fmov w0, s0 +; CHECK-NEXT: ret + %t1 = bitcast half %a to i16 + %t2 = call i16 @llvm.ctpop(i16 %t1) + %t3 = zext i16 %t2 to i32 + ret i32 %t3 +} + +define i32 @non_vmask_popcount_2(<8 x i16> %a) { +; CHECK-LABEL: non_vmask_popcount_2: +; CHECK: // %bb.0: +; CHECK-NEXT: sub sp, sp, #16 +; CHECK-NEXT: .cfi_def_cfa_offset 16 +; CHECK-NEXT: xtn v0.8b, v0.8h +; CHECK-NEXT: umov w8, v0.b[0] +; CHECK-NEXT: umov w9, v0.b[1] +; CHECK-NEXT: umov w10, v0.b[2] +; CHECK-NEXT: and w8, w8, #0x3 +; CHECK-NEXT: bfi w8, w9, #2, #2 +; CHECK-NEXT: umov w9, v0.b[3] +; CHECK-NEXT: bfi w8, w10, #4, #2 +; CHECK-NEXT: umov w10, v0.b[4] +; CHECK-NEXT: bfi w8, w9, #6, #2 +; CHECK-NEXT: umov w9, v0.b[5] +; CHECK-NEXT: bfi w8, w10, #8, #2 +; CHECK-NEXT: umov w10, v0.b[6] +; CHECK-NEXT: bfi w8, w9, #10, #2 +; CHECK-NEXT: umov w9, v0.b[7] +; CHECK-NEXT: bfi w8, w10, #12, #2 +; CHECK-NEXT: orr w8, w8, w9, lsl #14 +; CHECK-NEXT: and w8, w8, #0xffff +; CHECK-NEXT: fmov s0, w8 +; CHECK-NEXT: cnt v0.8b, v0.8b +; CHECK-NEXT: addv b0, v0.8b +; CHECK-NEXT: fmov w0, s0 +; CHECK-NEXT: add sp, sp, #16 +; CHECK-NEXT: ret + %mask = trunc <8 x i16> %a to <8 x i2> + %t1 = bitcast <8 x i2> %mask to i16 + %t2 = call i16 @llvm.ctpop(i16 %t1) + %t3 = zext i16 %t2 to i32 + ret i32 %t3 +}