-
Notifications
You must be signed in to change notification settings - Fork 11.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ConstantRange] Handle Intrinsic::cttz
#67917
Conversation
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-llvm-ir Changes: This patch adds support for cttz and ctpop intrinsics in ConstantRange. It calculates the range in O(1) with the LCP-based method. Migrated from https://reviews.llvm.org/D153505. Full diff: https://github.com/llvm/llvm-project/pull/67917.diff 4 Files Affected:
diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h
index ca36732e4e2e8c2..e718e6e7e3403de 100644
--- a/llvm/include/llvm/IR/ConstantRange.h
+++ b/llvm/include/llvm/IR/ConstantRange.h
@@ -530,6 +530,13 @@ class [[nodiscard]] ConstantRange {
/// ignoring a possible zero value contained in the input range.
ConstantRange ctlz(bool ZeroIsPoison = false) const;
+ /// Calculate cttz range. If \p ZeroIsPoison is set, the range is computed
+ /// ignoring a possible zero value contained in the input range.
+ ConstantRange cttz(bool ZeroIsPoison = false) const;
+
+ /// Calculate ctpop range.
+ ConstantRange ctpop() const;
+
/// Represents whether an operation on the given constant range is known to
/// always or never overflow.
enum class OverflowResult {
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 3d71b20f7e853e0..f34a2749543c321 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -949,6 +949,8 @@ bool ConstantRange::isIntrinsicSupported(Intrinsic::ID IntrinsicID) {
case Intrinsic::smax:
case Intrinsic::abs:
case Intrinsic::ctlz:
+ case Intrinsic::cttz:
+ case Intrinsic::ctpop:
return true;
default:
return false;
@@ -986,6 +988,15 @@ ConstantRange ConstantRange::intrinsic(Intrinsic::ID IntrinsicID,
assert(ZeroIsPoison->getBitWidth() == 1 && "Must be boolean");
return Ops[0].ctlz(ZeroIsPoison->getBoolValue());
}
+ case Intrinsic::cttz: {
+ const APInt *ZeroIsPoison = Ops[1].getSingleElement();
+ assert(ZeroIsPoison && "Must be known (immarg)");
+ assert(ZeroIsPoison->getBitWidth() == 1 && "Must be boolean");
+ return Ops[0].cttz(ZeroIsPoison->getBoolValue());
+ }
+ case Intrinsic::ctpop: {
+ return Ops[0].ctpop();
+ }
default:
assert(!isIntrinsicSupported(IntrinsicID) && "Shouldn't be supported");
llvm_unreachable("Unsupported intrinsic");
@@ -1735,6 +1746,122 @@ ConstantRange ConstantRange::ctlz(bool ZeroIsPoison) const {
return getNonEmpty(APInt(getBitWidth(), getUnsignedMax().countl_zero()),
APInt(getBitWidth(), getUnsignedMin().countl_zero() + 1));
}
+/// Range of countr_zero() over the non-wrapped half-open interval
+/// [Lower, Upper). Returns the empty set when the interval has no elements.
+static ConstantRange getUnsignedCountTrailingZerosRange(const APInt &Lower,
+                                                        const APInt &Upper) {
+  assert(Lower.ule(Upper));
+  const unsigned Width = Lower.getBitWidth();
+
+  // [N, N): there is no value to inspect.
+  if (Lower == Upper)
+    return ConstantRange::getEmpty(Width);
+
+  // Exactly one value: the answer is that value's own count.
+  const unsigned LowerTZ = Lower.countr_zero();
+  if (Lower + 1 == Upper)
+    return ConstantRange(APInt(Width, LowerTZ));
+
+  // From here on the interval holds at least two consecutive values, one of
+  // which is odd, so the smallest count is always 0.
+  const APInt SmallestCount = APInt::getZero(Width);
+
+  // An interval starting at zero gets every count from 0 up to the width.
+  if (Lower.isZero())
+    return ConstantRange(SmallestCount, APInt(Width, Width + 1));
+
+  // Length of the longest common prefix shared by Lower and the last value.
+  const APInt Last = Upper - 1;
+  const unsigned PrefixLen = (Lower ^ Last).countl_zero();
+
+  // Two candidates for the largest count:
+  //  * {LCP, 100...} lies in the interval and has Width - PrefixLen - 1
+  //    trailing zeros;
+  //  * Lower itself, which wins when it has the form {LCP, 000...}.
+  // Both are turned into exclusive bounds (+1) before taking the maximum.
+  const unsigned EndViaSplit = Width - PrefixLen;
+  const unsigned EndViaLower = LowerTZ + 1;
+  const unsigned ExclusiveEnd = std::max(EndViaSplit, EndViaLower);
+  return ConstantRange(SmallestCount, APInt(Width, ExclusiveEnd));
+}
+
+/// Compute the range of cttz results for the values in this range. When
+/// \p ZeroIsPoison is set, a zero contained in the input is ignored.
+ConstantRange ConstantRange::cttz(bool ZeroIsPoison) const {
+  if (isEmptySet())
+    return getEmpty();
+
+  const unsigned Width = getBitWidth();
+  const APInt Zero = APInt::getZero(Width);
+
+  // Result for the part of the range that runs from Lower up to
+  // getUnsignedMax(): [0, logBase2(number of such values) + 1).
+  auto UpperPartResult = [&]() {
+    const APInt NumValues = getUnsignedMax() - getLower() + 1;
+    return ConstantRange(Zero, APInt(Width, NumValues.logBase2() + 1));
+  };
+
+  if (ZeroIsPoison && contains(Zero)) {
+    // Zero is in the range but must not contribute. It can sit in three
+    // places:
+    //  1) at the front (Lower is zero), e.g. [0, 1), [0, 2);
+    //  2) at the back of a wrapped range (Upper is one), e.g. [3, 1);
+    //  3) strictly inside a wrapped range, e.g. [3, 2).
+    const bool UpperIsOne = (getUpper() - 1).isZero();
+
+    if (getLower().isZero()) {
+      // [0, 1) contains nothing but the poison input, so no result remains.
+      if (UpperIsOne)
+        return getEmpty();
+
+      // Step past the leading zero and analyse what is left.
+      return getUnsignedCountTrailingZerosRange(getLower() + 1, getUpper());
+    }
+
+    // Zero is the final element: only the part before the wrap matters.
+    if (UpperIsOne)
+      return UpperPartResult();
+
+    // Zero is in the middle: combine the part before the wrap with
+    // [1, Upper).
+    const ConstantRange BeforeWrap = UpperPartResult();
+    const ConstantRange AfterWrap =
+        getUnsignedCountTrailingZerosRange(APInt(Width, 1), getUpper());
+    return BeforeWrap.unionWith(AfterWrap);
+  }
+
+  // Every value is possible, so every count in [0, Width] is possible.
+  if (isFullSet())
+    return getNonEmpty(Zero, APInt(Width, Width + 1));
+
+  if (!isUpperWrapped())
+    return getUnsignedCountTrailingZerosRange(getLower(), getUpper());
+
+  // Wrapped range: handle the part before the wrap and [0, Upper)
+  // separately, then merge the two results.
+  const ConstantRange BeforeWrap = UpperPartResult();
+  const ConstantRange AfterWrap =
+      getUnsignedCountTrailingZerosRange(Zero, getUpper());
+  return BeforeWrap.unionWith(AfterWrap);
+}
+
+/// Range of popcount() over the non-wrapped half-open interval
+/// [Lower, Upper). Returns the empty set when the interval has no elements.
+static ConstantRange getUnsignedPopCountRange(const APInt &Lower,
+                                              const APInt &Upper) {
+  assert(Lower.ule(Upper));
+  const unsigned Width = Lower.getBitWidth();
+
+  // [N, N): there is no value to inspect.
+  if (Lower == Upper)
+    return ConstantRange::getEmpty(Width);
+
+  // Exactly one value: the answer is that value's own popcount.
+  if (Lower + 1 == Upper)
+    return ConstantRange(APInt(Width, Lower.popcount()));
+
+  // Split every value into the longest common prefix (LCP) shared by Lower
+  // and the last value, and the remaining low "suffix" bits.
+  const APInt Last = Upper - 1;
+  const unsigned PrefixLen = (Lower ^ Last).countl_zero();
+  const unsigned SuffixLen = Width - PrefixLen;
+  const unsigned PrefixOnes = Lower.getHiBits(PrefixLen).popcount();
+
+  // Minimum: popcount(LCP) when Lower is {LCP, 000...}; otherwise at least
+  // one suffix bit is set, giving popcount(LCP) + 1.
+  unsigned MinOnes = PrefixOnes;
+  if (Lower.countr_zero() < SuffixLen)
+    ++MinOnes;
+
+  // Maximum: popcount(LCP) + SuffixLen when the last value is
+  // {LCP, 111...}; otherwise popcount(LCP) + SuffixLen - 1. EndOnes is the
+  // exclusive bound, i.e. that maximum plus one.
+  unsigned EndOnes = PrefixOnes + SuffixLen;
+  if (Last.countr_one() >= SuffixLen)
+    ++EndOnes;
+
+  return ConstantRange(APInt(Width, MinOnes), APInt(Width, EndOnes));
+}
+
+/// Compute the range of ctpop results for the values in this range.
+ConstantRange ConstantRange::ctpop() const {
+  if (isEmptySet())
+    return getEmpty();
+
+  const unsigned Width = getBitWidth();
+  const APInt Zero = APInt::getZero(Width);
+
+  // Every value is possible, so every count in [0, Width] is possible.
+  if (isFullSet())
+    return getNonEmpty(Zero, APInt(Width, Width + 1));
+
+  if (!isUpperWrapped())
+    return getUnsignedPopCountRange(getLower(), getUpper());
+
+  // Wrapped range: treat the part from Lower up to getUnsignedMax() and the
+  // part [0, Upper) separately, then merge the two results.
+  const APInt NumHighValues = getUnsignedMax() - getLower() + 1;
+  const ConstantRange HighPart(
+      APInt(Width, Width - NumHighValues.logBase2()),
+      APInt(Width, Width + 1));
+  const ConstantRange LowPart = getUnsignedPopCountRange(Zero, getUpper());
+  return HighPart.unionWith(LowPart);
+}
ConstantRange::OverflowResult ConstantRange::unsignedAddMayOverflow(
const ConstantRange &Other) const {
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/range.ll b/llvm/test/Transforms/CorrelatedValuePropagation/range.ll
index 7e89f864c8110ee..182a0bbef255de8 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/range.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/range.ll
@@ -1010,6 +1010,60 @@ else:
ret i1 %res2
}
+; Branches on %x ult 256 and calls llvm.cttz.i16 with the zero-is-poison
+; flag set (i1 true) on both sides. The CHECK lines expect the 'if' side
+; comparison (cttz uge 8) to fold to false, and the 'else' side comparison
+; (cttz ult 8) to remain in the output.
+define i1 @cttz_fold(i16 %x) {
+; CHECK-LABEL: @cttz_fold(
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i16 [[X:%.*]], 256
+; CHECK-NEXT: br i1 [[CMP]], label [[IF:%.*]], label [[ELSE:%.*]]
+; CHECK: if:
+; CHECK-NEXT: [[CTTZ:%.*]] = call i16 @llvm.cttz.i16(i16 [[X]], i1 true)
+; CHECK-NEXT: ret i1 false
+; CHECK: else:
+; CHECK-NEXT: [[CTTZ2:%.*]] = call i16 @llvm.cttz.i16(i16 [[X]], i1 true)
+; CHECK-NEXT: [[RES2:%.*]] = icmp ult i16 [[CTTZ2]], 8
+; CHECK-NEXT: ret i1 [[RES2]]
+;
+ %cmp = icmp ult i16 %x, 256
+ br i1 %cmp, label %if, label %else
+
+if:
+ %cttz = call i16 @llvm.cttz.i16(i16 %x, i1 true)
+ %res = icmp uge i16 %cttz, 8
+ ret i1 %res
+
+else:
+ %cttz2 = call i16 @llvm.cttz.i16(i16 %x, i1 true)
+ %res2 = icmp ult i16 %cttz2, 8
+ ret i1 %res2
+}
+
+; Branches on %x ult 256 and calls llvm.ctpop.i16 on both sides. The CHECK
+; lines expect the 'if' side comparison (ctpop ule 8) to fold to true, and
+; the 'else' side comparison (ctpop ugt 8) to remain in the output.
+define i1 @ctpop_fold(i16 %x) {
+; CHECK-LABEL: @ctpop_fold(
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i16 [[X:%.*]], 256
+; CHECK-NEXT: br i1 [[CMP]], label [[IF:%.*]], label [[ELSE:%.*]]
+; CHECK: if:
+; CHECK-NEXT: [[CTPOP:%.*]] = call i16 @llvm.ctpop.i16(i16 [[X]])
+; CHECK-NEXT: ret i1 true
+; CHECK: else:
+; CHECK-NEXT: [[CTPOP2:%.*]] = call i16 @llvm.ctpop.i16(i16 [[X]])
+; CHECK-NEXT: [[RES2:%.*]] = icmp ugt i16 [[CTPOP2]], 8
+; CHECK-NEXT: ret i1 [[RES2]]
+;
+ %cmp = icmp ult i16 %x, 256
+ br i1 %cmp, label %if, label %else
+
+if:
+ %ctpop = call i16 @llvm.ctpop.i16(i16 %x)
+ %res = icmp ule i16 %ctpop, 8
+ ret i1 %res
+
+else:
+ %ctpop2 = call i16 @llvm.ctpop.i16(i16 %x)
+ %res2 = icmp ugt i16 %ctpop2, 8
+ ret i1 %res2
+}
+
declare i16 @llvm.ctlz.i16(i16, i1)
+declare i16 @llvm.cttz.i16(i16, i1)
+declare i16 @llvm.ctpop.i16(i16)
declare i16 @llvm.abs.i16(i16, i1)
declare void @llvm.assume(i1)
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
index 1cb358a26062ca5..e505af5d3275ef2 100644
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -2438,6 +2438,26 @@ TEST_F(ConstantRangeTest, Ctlz) {
});
}
+// Exhaustively compares ConstantRange::cttz against APInt::countr_zero, first
+// with zero as an ordinary input and then with zero treated as poison.
+TEST_F(ConstantRangeTest, Cttz) {
+  // Default mode: no input value is excluded.
+  auto RangeCttz = [](const ConstantRange &Range) { return Range.cttz(); };
+  auto ValueCttz = [](const APInt &Value) {
+    return APInt(Value.getBitWidth(), Value.countr_zero());
+  };
+  TestUnaryOpExhaustive(RangeCttz, ValueCttz);
+
+  // Poison mode: the reference function yields std::nullopt for zero.
+  auto RangeCttzPoison = [](const ConstantRange &Range) {
+    return Range.cttz(/*ZeroIsPoison=*/true);
+  };
+  auto ValueCttzPoison = [](const APInt &Value) -> std::optional<APInt> {
+    if (Value.isZero())
+      return std::nullopt;
+    return APInt(Value.getBitWidth(), Value.countr_zero());
+  };
+  TestUnaryOpExhaustive(RangeCttzPoison, ValueCttzPoison);
+}
+
+// Exhaustively compares ConstantRange::ctpop against APInt::popcount.
+TEST_F(ConstantRangeTest, Ctpop) {
+  auto RangeCtpop = [](const ConstantRange &Range) { return Range.ctpop(); };
+  auto ValueCtpop = [](const APInt &Value) {
+    return APInt(Value.getBitWidth(), Value.popcount());
+  };
+  TestUnaryOpExhaustive(RangeCtpop, ValueCtpop);
+}
+
TEST_F(ConstantRangeTest, castOps) {
ConstantRange A(APInt(16, 66), APInt(16, 128));
ConstantRange FpToI8 = A.castOp(Instruction::FPToSI, 8);
|
Please split this into separate patches for cttz and ctpop. |
Intrinsic::cttz
and Intrinsic::ctpop
Intrinsic::cttz
Ping. |
LGTM. |
@nikic Any comments about this PR? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think most of the comments from the ctpop review also apply to this one.
a2ea2e9
to
8860c06
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
64565e6
to
9d16764
Compare
This patch adds support for
Intrinsic::cttz
in ConstantRange. It calculates the range in O(1) with the LCP-based method. Migrated from https://reviews.llvm.org/D153505.