diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h index 9a6a9db65688a..4b2fda364fdf4 100644 --- a/llvm/include/llvm/IR/ConstantRange.h +++ b/llvm/include/llvm/IR/ConstantRange.h @@ -380,8 +380,9 @@ class [[nodiscard]] ConstantRange { /// Return a new range in the specified integer type, which must be /// strictly smaller than the current type. The returned range will /// correspond to the possible range of values if the source range had been - /// truncated to the specified type. - LLVM_ABI ConstantRange truncate(uint32_t BitWidth) const; + /// truncated to the specified type with wrap type \p NoWrapKind. + LLVM_ABI ConstantRange truncate(uint32_t BitWidth, + unsigned NoWrapKind = 0) const; /// Make this range have the bit width given by \p BitWidth. The /// value is zero extended, truncated, or left alone to make it that width. diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp index 2fcdbcc6a3db2..b454c9a4cd3ae 100644 --- a/llvm/lib/IR/ConstantRange.cpp +++ b/llvm/lib/IR/ConstantRange.cpp @@ -872,7 +872,8 @@ ConstantRange ConstantRange::signExtend(uint32_t DstTySize) const { return ConstantRange(Lower.sext(DstTySize), Upper.sext(DstTySize)); } -ConstantRange ConstantRange::truncate(uint32_t DstTySize) const { +ConstantRange ConstantRange::truncate(uint32_t DstTySize, + unsigned NoWrapKind) const { assert(getBitWidth() > DstTySize && "Not a value truncation"); if (isEmptySet()) return getEmpty(DstTySize); @@ -886,22 +887,36 @@ ConstantRange ConstantRange::truncate(uint32_t DstTySize) const { // We use the non-wrapped set code to analyze the [Lower, MaxValue) part, and // then we do the union with [MaxValue, Upper) if (isUpperWrapped()) { - // If Upper is greater than or equal to MaxValue(DstTy), it covers the whole - // truncated range. - if (Upper.getActiveBits() > DstTySize || Upper.countr_one() == DstTySize) + // If Upper is greater than MaxValue(DstTy), it covers the whole truncated + // range. + if (Upper.getActiveBits() > DstTySize) return getFull(DstTySize); - Union = ConstantRange(APInt::getMaxValue(DstTySize),Upper.trunc(DstTySize)); - UpperDiv.setAllBits(); - - // Union covers the MaxValue case, so return if the remaining range is just - // MaxValue(DstTy). - if (LowerDiv == UpperDiv) - return Union; + // For nuw the two parts are: [0, Upper) \/ [Lower, MaxValue(DstTy)] + if (NoWrapKind & TruncInst::NoUnsignedWrap) { + Union = ConstantRange(APInt::getZero(DstTySize), Upper.trunc(DstTySize)); + UpperDiv = APInt::getOneBitSet(getBitWidth(), DstTySize); + } else { + // If Upper is equal to MaxValue(DstTy), it covers the whole truncated + // range. + if (Upper.countr_one() == DstTySize) + return getFull(DstTySize); + Union = + ConstantRange(APInt::getMaxValue(DstTySize), Upper.trunc(DstTySize)); + UpperDiv.setAllBits(); + // Union covers the MaxValue case, so return if the remaining range is + // just MaxValue(DstTy). + if (LowerDiv == UpperDiv) + return Union; + } } // Chop off the most significant bits that are past the destination bitwidth. if (LowerDiv.getActiveBits() > DstTySize) { + // For trunc nuw if LowerDiv is greater than MaxValue(DstTy), the range is + // outside the whole truncated range. + if (NoWrapKind & TruncInst::NoUnsignedWrap) + return Union; // Mask to just the signficant bits and subtract from LowerDiv/UpperDiv. APInt Adjust = LowerDiv & APInt::getBitsSetFrom(getBitWidth(), DstTySize); LowerDiv -= Adjust; @@ -913,6 +928,10 @@ ConstantRange ConstantRange::truncate(uint32_t DstTySize) const { return ConstantRange(LowerDiv.trunc(DstTySize), UpperDiv.trunc(DstTySize)).unionWith(Union); + if (!LowerDiv.isZero() && NoWrapKind & TruncInst::NoUnsignedWrap) + return ConstantRange(LowerDiv.trunc(DstTySize), APInt::getZero(DstTySize)) + .unionWith(Union); + // The truncated value wraps around. Check if we can do better than fullset. if (UpperDivWidth == DstTySize + 1) { // Clear the MSB so that UpperDiv wraps around. diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp index bcb5d498c8cb9..53d581c8db7c9 100644 --- a/llvm/unittests/IR/ConstantRangeTest.cpp +++ b/llvm/unittests/IR/ConstantRangeTest.cpp @@ -451,6 +451,65 @@ TEST_F(ConstantRangeTest, Trunc) { EXPECT_EQ(SevenOne.truncate(2), ConstantRange(APInt(2, 3), APInt(2, 1))); } +TEST_F(ConstantRangeTest, TruncNuw) { + auto Range = [](unsigned NumBits, unsigned Lower, unsigned Upper) { + return ConstantRange(APInt(NumBits, Lower), APInt(NumBits, Upper)); + }; + // trunc([0, 4), 3->2) = full + EXPECT_TRUE( + Range(3, 0, 4).truncate(2, TruncInst::NoUnsignedWrap).isFullSet()); + // trunc([0, 3), 3->2) = [0, 3) + EXPECT_EQ(Range(3, 0, 3).truncate(2, TruncInst::NoUnsignedWrap), + Range(2, 0, 3)); + // trunc([1, 3), 3->2) = [1, 3) + EXPECT_EQ(Range(3, 1, 3).truncate(2, TruncInst::NoUnsignedWrap), + Range(2, 1, 3)); + // trunc([1, 5), 3->2) = [1, 0) + EXPECT_EQ(Range(3, 1, 5).truncate(2, TruncInst::NoUnsignedWrap), + Range(2, 1, 0)); + // trunc([4, 7), 3->2) = empty + EXPECT_TRUE( + Range(3, 4, 7).truncate(2, TruncInst::NoUnsignedWrap).isEmptySet()); + // trunc([4, 0), 3->2) = empty + EXPECT_TRUE( + Range(3, 4, 0).truncate(2, TruncInst::NoUnsignedWrap).isEmptySet()); + // trunc([4, 1), 3->2) = [0, 1) + EXPECT_EQ(Range(3, 4, 1).truncate(2, TruncInst::NoUnsignedWrap), + Range(2, 0, 1)); + // trunc([3, 1), 3->2) = [3, 1) + EXPECT_EQ(Range(3, 3, 1).truncate(2, TruncInst::NoUnsignedWrap), + Range(2, 3, 1)); + // trunc([3, 0), 3->2) = [3, 0) + EXPECT_EQ(Range(3, 3, 0).truncate(2, TruncInst::NoUnsignedWrap), + Range(2, 3, 0)); + // trunc([1, 0), 2->1) = [1, 0) + EXPECT_EQ(Range(2, 1, 0).truncate(1, TruncInst::NoUnsignedWrap), + Range(1, 1, 0)); + // trunc([2, 1), 2->1) = [0, 1) + EXPECT_EQ(Range(2, 2, 1).truncate(1, TruncInst::NoUnsignedWrap), + Range(1, 0, 1)); +} + +TEST_F(ConstantRangeTest, TruncNuwExhaustive) { + EnumerateConstantRanges(4, [&](const ConstantRange &CR) { + unsigned NumBits = 3; + ConstantRange Trunc = CR.truncate(NumBits, TruncInst::NoUnsignedWrap); + SmallBitVector Elems(1 << NumBits); + ForeachNumInConstantRange(CR, [&](const APInt &N) { + if (N.isIntN(NumBits)) + Elems.set(N.getZExtValue()); + }); + TestRange(Trunc, Elems, PreferSmallest, {CR}); + }); + EnumerateConstantRanges(3, [&](const ConstantRange &CR) { + ConstantRange Trunc = CR.truncate(1, TruncInst::NoUnsignedWrap); + EXPECT_EQ(CR.contains(APInt::getZero(3)), + Trunc.contains(APInt::getZero(1))); + EXPECT_EQ(CR.contains(APInt::getOneBitSet(3, 0)), + Trunc.contains(APInt::getAllOnes(1))); + }); +} + TEST_F(ConstantRangeTest, ZExt) { ConstantRange ZFull = Full.zeroExtend(20); ConstantRange ZEmpty = Empty.zeroExtend(20);