From 103684b8e84441e2e049e2b2db37c19e6628e9e8 Mon Sep 17 00:00:00 2001 From: Nikita Popov Date: Fri, 26 May 2023 14:03:12 +0200 Subject: [PATCH] [KnownBits] Partially synchronize shift implementations (NFC) And remove some bits of effectively dead code. --- llvm/lib/Support/KnownBits.cpp | 124 ++++++++++++--------------------- 1 file changed, 44 insertions(+), 80 deletions(-) diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp index c8e4a8981666a..a7ca7c05dad2b 100644 --- a/llvm/lib/Support/KnownBits.cpp +++ b/llvm/lib/Support/KnownBits.cpp @@ -199,11 +199,6 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW, KnownBits Known(BitWidth); unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth); if (LHS.isUnknown()) { - if (MinShiftAmount == BitWidth) { - // Always poison. Return zero because we don't like returning conflict. - Known.setAllZero(); - return Known; - } Known.Zero.setLowBits(MinShiftAmount); if (NUW && NSW && MinShiftAmount != 0) Known.makeNonNegative(); @@ -261,120 +256,89 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW, KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS) { unsigned BitWidth = LHS.getBitWidth(); - KnownBits Known(BitWidth); - - if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) { - unsigned Shift = RHS.getConstant().getZExtValue(); - Known = LHS; - Known.Zero.lshrInPlace(Shift); - Known.One.lshrInPlace(Shift); + auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) { + KnownBits Known = LHS; + Known.Zero.lshrInPlace(ShiftAmt); + Known.One.lshrInPlace(ShiftAmt); // High bits are known zero. - Known.Zero.setHighBits(Shift); + Known.Zero.setHighBits(ShiftAmt); return Known; - } - - // Minimum shift amount high bits are known zero. - APInt MinShiftAmount = RHS.getMinValue(); - if (MinShiftAmount.uge(BitWidth)) { - // Always poison. Return zero because we don't like returning conflict. - Known.setAllZero(); - return Known; - } + }; + // Fast path for a common case when LHS is completely unknown. + KnownBits Known(BitWidth); + unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth); if (LHS.isUnknown()) { - // No matter the shift amount, the leading zeros will stay zero. - unsigned MinLeadingZeros = LHS.countMinLeadingZeros(); - MinLeadingZeros += MinShiftAmount.getZExtValue(); - MinLeadingZeros = std::min(MinLeadingZeros, BitWidth); - Known.Zero.setHighBits(MinLeadingZeros); + Known.Zero.setHighBits(MinShiftAmount); return Known; } // Find the common bits from all possible shifts. - APInt MaxShiftAmount = RHS.getMaxValue(); - uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue(); - uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue(); - assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range"); + APInt MaxValue = RHS.getMaxValue(); + unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth); + unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue(); + unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue(); Known.Zero.setAllBits(); Known.One.setAllBits(); - for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(), - MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1); - ShiftAmt <= MaxShiftAmt; ++ShiftAmt) { + for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount; + ++ShiftAmt) { // Skip if the shift amount is impossible. - if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt || + if ((ShiftAmtZeroMask & ShiftAmt) != 0 || (ShiftAmtOneMask | ShiftAmt) != ShiftAmt) continue; - KnownBits SpecificShift = LHS; - SpecificShift.Zero.lshrInPlace(ShiftAmt); - SpecificShift.Zero.setHighBits(ShiftAmt); - SpecificShift.One.lshrInPlace(ShiftAmt); - Known = Known.intersectWith(SpecificShift); + Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt)); if (Known.isUnknown()) break; } + // All shift amounts may result in poison. + if (Known.hasConflict()) + Known.setAllZero(); return Known; } KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS) { unsigned BitWidth = LHS.getBitWidth(); - KnownBits Known(BitWidth); - - if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) { - unsigned Shift = RHS.getConstant().getZExtValue(); - Known = LHS; - Known.Zero.ashrInPlace(Shift); - Known.One.ashrInPlace(Shift); - return Known; - } - - // Minimum shift amount high bits are known sign bits. - APInt MinShiftAmount = RHS.getMinValue(); - if (MinShiftAmount.uge(BitWidth)) { - // Always poison. Return zero because we don't like returning conflict. - Known.setAllZero(); + auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) { + KnownBits Known = LHS; + Known.Zero.ashrInPlace(ShiftAmt); + Known.One.ashrInPlace(ShiftAmt); return Known; - } + }; + // Fast path for a common case when LHS is completely unknown. + KnownBits Known(BitWidth); + unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth); if (LHS.isUnknown()) { - // No matter the shift amount, the leading sign bits will stay. - unsigned MinLeadingZeros = LHS.countMinLeadingZeros(); - unsigned MinLeadingOnes = LHS.countMinLeadingOnes(); - if (MinLeadingZeros) { - MinLeadingZeros += MinShiftAmount.getZExtValue(); - MinLeadingZeros = std::min(MinLeadingZeros, BitWidth); - } - if (MinLeadingOnes) { - MinLeadingOnes += MinShiftAmount.getZExtValue(); - MinLeadingOnes = std::min(MinLeadingOnes, BitWidth); + if (MinShiftAmount == BitWidth) { + // Always poison. Return zero because we don't like returning conflict. + Known.setAllZero(); + return Known; } - Known.Zero.setHighBits(MinLeadingZeros); - Known.One.setHighBits(MinLeadingOnes); return Known; } // Find the common bits from all possible shifts. - APInt MaxShiftAmount = RHS.getMaxValue(); - uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue(); - uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue(); - assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range"); + APInt MaxValue = RHS.getMaxValue(); + unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth); + unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue(); + unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue(); Known.Zero.setAllBits(); Known.One.setAllBits(); - for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(), - MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1); - ShiftAmt <= MaxShiftAmt; ++ShiftAmt) { + for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount; + ++ShiftAmt) { // Skip if the shift amount is impossible. - if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt || + if ((ShiftAmtZeroMask & ShiftAmt) != 0 || (ShiftAmtOneMask | ShiftAmt) != ShiftAmt) continue; - KnownBits SpecificShift = LHS; - SpecificShift.Zero.ashrInPlace(ShiftAmt); - SpecificShift.One.ashrInPlace(ShiftAmt); - Known = Known.intersectWith(SpecificShift); + Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt)); if (Known.isUnknown()) break; } + // All shift amounts may result in poison. + if (Known.hasConflict()) + Known.setAllZero(); return Known; }