diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 2fc8b0c9a22cd..5c574b91e3ed0 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -19443,20 +19443,37 @@ AArch64TargetLowering::BuildSREMPow2(SDNode *N, const APInt &Divisor,
   return CSNeg;
 }
 
-static std::optional<unsigned> IsSVECntIntrinsic(SDValue S) {
+static bool IsSVECntIntrinsic(SDValue S) {
   switch(getIntrinsicID(S.getNode())) {
   default:
     break;
   case Intrinsic::aarch64_sve_cntb:
-    return 8;
   case Intrinsic::aarch64_sve_cnth:
-    return 16;
   case Intrinsic::aarch64_sve_cntw:
-    return 32;
   case Intrinsic::aarch64_sve_cntd:
-    return 64;
+    return true;
+  }
+  return false;
+}
+
+// Returns the maximum (scalable) value that can be returned by an SVE count
+// intrinsic. Returns std::nullopt if \p Op is not aarch64_sve_cnt*.
+static std::optional<ElementCount> getMaxValueForSVECntIntrinsic(SDValue Op) {
+  Intrinsic::ID IID = getIntrinsicID(Op.getNode());
+  if (IID == Intrinsic::aarch64_sve_cntp)
+    return Op.getOperand(1).getValueType().getVectorElementCount();
+  switch (IID) {
+  case Intrinsic::aarch64_sve_cntd:
+    return ElementCount::getScalable(2);
+  case Intrinsic::aarch64_sve_cntw:
+    return ElementCount::getScalable(4);
+  case Intrinsic::aarch64_sve_cnth:
+    return ElementCount::getScalable(8);
+  case Intrinsic::aarch64_sve_cntb:
+    return ElementCount::getScalable(16);
+  default:
+    return std::nullopt;
   }
-  return {};
 }
 
 /// Calculates what the pre-extend type is, based on the extension
@@ -31666,22 +31683,24 @@ bool AArch64TargetLowering::SimplifyDemandedBitsForTargetNode(
     return false;
   }
   case ISD::INTRINSIC_WO_CHAIN: {
-    if (auto ElementSize = IsSVECntIntrinsic(Op)) {
-      unsigned MaxSVEVectorSizeInBits = Subtarget->getMaxSVEVectorSizeInBits();
-      if (!MaxSVEVectorSizeInBits)
-        MaxSVEVectorSizeInBits = AArch64::SVEMaxBitsPerVector;
-      unsigned MaxElements = MaxSVEVectorSizeInBits / *ElementSize;
-      // The SVE count intrinsics don't support the multiplier immediate so we
-      // don't have to account for that here. The value returned may be slightly
-      // over the true required bits, as this is based on the "ALL" pattern. The
-      // other patterns are also exposed by these intrinsics, but they all
-      // return a value that's strictly less than "ALL".
-      unsigned RequiredBits = llvm::bit_width(MaxElements);
-      unsigned BitWidth = Known.Zero.getBitWidth();
-      if (RequiredBits < BitWidth)
-        Known.Zero.setHighBits(BitWidth - RequiredBits);
+    std::optional<ElementCount> MaxCount = getMaxValueForSVECntIntrinsic(Op);
+    if (!MaxCount)
       return false;
-    }
+    unsigned MaxSVEVectorSizeInBits = Subtarget->getMaxSVEVectorSizeInBits();
+    if (!MaxSVEVectorSizeInBits)
+      MaxSVEVectorSizeInBits = AArch64::SVEMaxBitsPerVector;
+    unsigned VscaleMax = MaxSVEVectorSizeInBits / 128;
+    unsigned MaxValue = MaxCount->getKnownMinValue() * VscaleMax;
+    // The SVE count intrinsics don't support the multiplier immediate so we
+    // don't have to account for that here. The value returned may be slightly
+    // over the true required bits, as this is based on the "ALL" pattern. The
+    // other patterns are also exposed by these intrinsics, but they all
+    // return a value that's strictly less than "ALL".
+    unsigned RequiredBits = llvm::bit_width(MaxValue);
+    unsigned BitWidth = Known.Zero.getBitWidth();
+    if (RequiredBits < BitWidth)
+      Known.Zero.setHighBits(BitWidth - RequiredBits);
+    return false;
   }
   }
 
diff --git a/llvm/test/CodeGen/AArch64/sve-vector-compress.ll b/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
index cc3a3734a9721..f700dee0fb2e4 100644
--- a/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
+++ b/llvm/test/CodeGen/AArch64/sve-vector-compress.ll
@@ -143,20 +143,19 @@ define <vscale x 8 x i32> @test_compress_large(<vscale x 8 x i32> %vec, <vscale x 8 x i1> %p) {
+define i64 @cntp_nxv16i1_and_elimination(<vscale x 16 x i1> %p) {
+; CHECK-LABEL: cntp_nxv16i1_and_elimination:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cntp x8, p0, p0.b
+; CHECK-NEXT:    and x9, x8, #0x1fc
+; CHECK-NEXT:    add x0, x8, x9
+; CHECK-NEXT:    ret
+  %cntp = tail call i64 @llvm.aarch64.sve.cntp.nxv16i1(<vscale x 16 x i1> %p, <vscale x 16 x i1> %p)
+  %and_redundant = and i64 %cntp, 511
+  %and_required = and i64 %cntp, 17179869180
+  %result = add i64 %and_redundant, %and_required
+  ret i64 %result
+}
+
+define i64 @cntp_nxv8i1_and_elimination(<vscale x 8 x i1> %p) {
+; CHECK-LABEL: cntp_nxv8i1_and_elimination:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cntp x8, p0, p0.h
+; CHECK-NEXT:    and x9, x8, #0xfc
+; CHECK-NEXT:    add x0, x8, x9
+; CHECK-NEXT:    ret
+  %cntp = tail call i64 @llvm.aarch64.sve.cntp.nxv8i1(<vscale x 8 x i1> %p, <vscale x 8 x i1> %p)
+  %and_redundant = and i64 %cntp, 1023
+  %and_required = and i64 %cntp, 17179869180
+  %result = add i64 %and_redundant, %and_required
+  ret i64 %result
+}
+
+define i64 @cntp_nxv4i1_and_elimination(<vscale x 4 x i1> %p) {
+; CHECK-LABEL: cntp_nxv4i1_and_elimination:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cntp x8, p0, p0.s
+; CHECK-NEXT:    and x9, x8, #0x7c
+; CHECK-NEXT:    add x0, x8, x9
+; CHECK-NEXT:    ret
+  %cntp = tail call i64 @llvm.aarch64.sve.cntp.nxv4i1(<vscale x 4 x i1> %p, <vscale x 4 x i1> %p)
+  %and_redundant = and i64 %cntp, 127
+  %and_required = and i64 %cntp, 17179869180
+  %result = add i64 %and_redundant, %and_required
+  ret i64 %result
+}
+
+define i64 @cntp_nxv2i1_and_elimination(<vscale x 2 x i1> %p) {
+; CHECK-LABEL: cntp_nxv2i1_and_elimination:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cntp x8, p0, p0.d
+; CHECK-NEXT:    and x9, x8, #0x3c
+; CHECK-NEXT:    add x0, x8, x9
+; CHECK-NEXT:    ret
+  %cntp = tail call i64 @llvm.aarch64.sve.cntp.nxv2i1(<vscale x 2 x i1> %p, <vscale x 2 x i1> %p)
+  %and_redundant = and i64 %cntp, 63
+  %and_required = and i64 %cntp, 17179869180
+  %result = add i64 %and_redundant, %and_required
+  ret i64 %result
+}
+
 define i64 @vscale_trunc_zext() vscale_range(1,16) {
 ; CHECK-LABEL: vscale_trunc_zext:
 ; CHECK:       // %bb.0: