[ConstantRange] Handle Intrinsic::cttz #67917

Merged: 2 commits merged on Nov 6, 2023

dtcxzyw (Member) commented Oct 1, 2023

This patch adds support for Intrinsic::cttz in ConstantRange. It calculates the range in O(1) using a method based on the longest common prefix (LCP) of the range bounds.

Migrated from https://reviews.llvm.org/D153505.
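
To illustrate the LCP idea outside of LLVM, here is a minimal standalone sketch. It is not the patch's code: it uses plain `uint32_t` values restricted to a hypothetical 8-bit width instead of `APInt`, works on inclusive intervals `[Lo, Hi]` with `Lo >= 1`, and all function names are made up for this example. It computes the largest number of trailing zeros in the interval and compares the result against a brute-force scan.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Small fixed width so that exhaustive checking stays cheap (assumption).
constexpr unsigned kBitWidth = 8;
constexpr uint32_t kLimit = 1u << kBitWidth;

// Trailing zero bits of V within kBitWidth bits; kBitWidth when V is zero.
unsigned countTrailingZeros(uint32_t V) {
  unsigned N = 0;
  while (N < kBitWidth && ((V >> N) & 1u) == 0)
    ++N;
  return N;
}

// Leading zero bits of V within kBitWidth bits; kBitWidth when V is zero.
unsigned countLeadingZeros(uint32_t V) {
  unsigned N = 0;
  while (N < kBitWidth && ((V >> (kBitWidth - 1 - N)) & 1u) == 0)
    ++N;
  return N;
}

// Largest cttz over the inclusive interval [Lo, Hi], for 1 <= Lo <= Hi.
unsigned maxCttzViaLCP(uint32_t Lo, uint32_t Hi) {
  assert(Lo >= 1 && Lo <= Hi && Hi < kLimit);
  if (Lo == Hi)
    return countTrailingZeros(Lo);
  // Lo and Hi agree on their top LCPLength bits.
  unsigned LCPLength = countLeadingZeros(Lo ^ Hi);
  // {LCP, 1, 000...} always lies in (Lo, Hi] and has this many trailing zeros.
  unsigned ViaSplit = kBitWidth - LCPLength - 1;
  // Lo itself can only do better when it is {LCP, 000...}.
  return std::max(ViaSplit, countTrailingZeros(Lo));
}

// Reference answer obtained by scanning every value in [Lo, Hi].
unsigned maxCttzBruteForce(uint32_t Lo, uint32_t Hi) {
  unsigned Best = 0;
  for (uint32_t V = Lo; V <= Hi; ++V)
    Best = std::max(Best, countTrailingZeros(V));
  return Best;
}

// True when both methods agree on every interval with 1 <= Lo <= Hi.
bool lcpMatchesBruteForce() {
  for (uint32_t Lo = 1; Lo < kLimit; ++Lo)
    for (uint32_t Hi = Lo; Hi < kLimit; ++Hi)
      if (maxCttzViaLCP(Lo, Hi) != maxCttzBruteForce(Lo, Hi))
        return false;
  return true;
}
```

For example, `maxCttzViaLCP(5, 7)` finds the common prefix of `101` and `111`, so the best candidate is `110` with one trailing zero. The minimum needs no such work: any interval with at least two values contains an odd number, so the lower bound is 0.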

llvmbot (Collaborator) commented Oct 1, 2023

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-ir

Changes

This patch adds support for cttz and ctpop intrinsics in ConstantRange. It calculates the range in O(1) with the LCP-based method.

Migrated from https://reviews.llvm.org/D153505.
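
The same LCP reasoning applies to ctpop. The following is a hedged standalone sketch rather than the patch's implementation: it uses `uint32_t` values at a hypothetical 8-bit width, inclusive intervals `[Lo, Hi]`, and illustrative names (`popCountBoundsViaLCP`, `PopCountBounds`) that do not exist in LLVM. It returns the inclusive minimum and maximum population count and is checked against a brute-force scan.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Small fixed width so that exhaustive checking stays cheap (assumption).
constexpr unsigned kBitWidth = 8;
constexpr uint32_t kLimit = 1u << kBitWidth;

// Number of set bits in V.
unsigned popCount(uint32_t V) {
  unsigned N = 0;
  for (; V != 0; V >>= 1)
    N += V & 1u;
  return N;
}

// Leading zero bits of V within kBitWidth bits; kBitWidth when V is zero.
unsigned countLeadingZeros(uint32_t V) {
  unsigned N = 0;
  while (N < kBitWidth && ((V >> (kBitWidth - 1 - N)) & 1u) == 0)
    ++N;
  return N;
}

// Inclusive bounds on the population count (hypothetical helper type).
struct PopCountBounds {
  unsigned Min;
  unsigned Max;
};

// Bounds on popcount over the inclusive interval [Lo, Hi], for Lo <= Hi.
PopCountBounds popCountBoundsViaLCP(uint32_t Lo, uint32_t Hi) {
  assert(Lo <= Hi && Hi < kLimit);
  if (Lo == Hi)
    return {popCount(Lo), popCount(Lo)};
  unsigned LCPLength = countLeadingZeros(Lo ^ Hi);
  unsigned SuffixBits = kBitWidth - LCPLength; // bits below the common prefix
  uint32_t SuffixMask = (1u << SuffixBits) - 1;
  unsigned LCPPopCount = popCount(Lo >> SuffixBits);
  // Lo == {LCP, 000...} reaches popcount(LCP); otherwise {LCP, 1, 000...}
  // is the sparsest value in range.
  unsigned Min = LCPPopCount + ((Lo & SuffixMask) != 0 ? 1 : 0);
  // Hi == {LCP, 111...} fills every suffix bit; otherwise {LCP, 0, 111...}
  // is the densest value in range.
  unsigned Max =
      LCPPopCount + SuffixBits - ((Hi & SuffixMask) == SuffixMask ? 0 : 1);
  return {Min, Max};
}

// True when the LCP bounds match a scan of every interval.
bool popCountBoundsMatchBruteForce() {
  for (uint32_t Lo = 0; Lo < kLimit; ++Lo)
    for (uint32_t Hi = Lo; Hi < kLimit; ++Hi) {
      unsigned Min = kBitWidth, Max = 0;
      for (uint32_t V = Lo; V <= Hi; ++V) {
        Min = std::min(Min, popCount(V));
        Max = std::max(Max, popCount(V));
      }
      PopCountBounds B = popCountBoundsViaLCP(Lo, Hi);
      if (B.Min != Min || B.Max != Max)
        return false;
    }
  return true;
}
```

For `[4, 7]` the common prefix is the single bit `1`, `Lo` is `{LCP, 00}` and `Hi` is `{LCP, 11}`, so the bounds come out as 1 and 3. A `ConstantRange` result would use `Max + 1` as its exclusive upper bound.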


Full diff: https://github.com/llvm/llvm-project/pull/67917.diff

4 Files Affected:

  • (modified) llvm/include/llvm/IR/ConstantRange.h (+7)
  • (modified) llvm/lib/IR/ConstantRange.cpp (+127)
  • (modified) llvm/test/Transforms/CorrelatedValuePropagation/range.ll (+54)
  • (modified) llvm/unittests/IR/ConstantRangeTest.cpp (+20)
diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h
index ca36732e4e2e8c2..e718e6e7e3403de 100644
--- a/llvm/include/llvm/IR/ConstantRange.h
+++ b/llvm/include/llvm/IR/ConstantRange.h
@@ -530,6 +530,13 @@ class [[nodiscard]] ConstantRange {
   /// ignoring a possible zero value contained in the input range.
   ConstantRange ctlz(bool ZeroIsPoison = false) const;
 
+  /// Calculate cttz range. If \p ZeroIsPoison is set, the range is computed
+  /// ignoring a possible zero value contained in the input range.
+  ConstantRange cttz(bool ZeroIsPoison = false) const;
+
+  /// Calculate ctpop range.
+  ConstantRange ctpop() const;
+
   /// Represents whether an operation on the given constant range is known to
   /// always or never overflow.
   enum class OverflowResult {
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 3d71b20f7e853e0..f34a2749543c321 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -949,6 +949,8 @@ bool ConstantRange::isIntrinsicSupported(Intrinsic::ID IntrinsicID) {
   case Intrinsic::smax:
   case Intrinsic::abs:
   case Intrinsic::ctlz:
+  case Intrinsic::cttz:
+  case Intrinsic::ctpop:
     return true;
   default:
     return false;
@@ -986,6 +988,15 @@ ConstantRange ConstantRange::intrinsic(Intrinsic::ID IntrinsicID,
     assert(ZeroIsPoison->getBitWidth() == 1 && "Must be boolean");
     return Ops[0].ctlz(ZeroIsPoison->getBoolValue());
   }
+  case Intrinsic::cttz: {
+    const APInt *ZeroIsPoison = Ops[1].getSingleElement();
+    assert(ZeroIsPoison && "Must be known (immarg)");
+    assert(ZeroIsPoison->getBitWidth() == 1 && "Must be boolean");
+    return Ops[0].cttz(ZeroIsPoison->getBoolValue());
+  }
+  case Intrinsic::ctpop: {
+    return Ops[0].ctpop();
+  }
   default:
     assert(!isIntrinsicSupported(IntrinsicID) && "Shouldn't be supported");
     llvm_unreachable("Unsupported intrinsic");
@@ -1735,6 +1746,122 @@ ConstantRange ConstantRange::ctlz(bool ZeroIsPoison) const {
   return getNonEmpty(APInt(getBitWidth(), getUnsignedMax().countl_zero()),
                      APInt(getBitWidth(), getUnsignedMin().countl_zero() + 1));
 }
+static ConstantRange getUnsignedCountTrailingZerosRange(const APInt &Lower,
+                                                        const APInt &Upper) {
+  assert(Lower.ule(Upper));
+  unsigned BitWidth = Lower.getBitWidth();
+  if (Lower == Upper)
+    return ConstantRange::getEmpty(BitWidth);
+  if (Lower + 1 == Upper)
+    return ConstantRange(APInt(BitWidth, Lower.countr_zero()));
+  if (Lower.isZero())
+    return ConstantRange(APInt::getZero(BitWidth),
+                         APInt(BitWidth, BitWidth + 1));
+
+  // Calculate longest common prefix.
+  unsigned LCPLength = (Lower ^ (Upper - 1)).countl_zero();
+  // If Lower is {LCP, 000...}, the maximum is Lower.countr_zero().
+  // Otherwise, the maximum is BitWidth - LCPLength - 1 ({LCP, 100...}).
+  return ConstantRange(
+      APInt::getZero(BitWidth),
+      APInt(BitWidth, std::max(BitWidth - LCPLength, Lower.countr_zero() + 1)));
+}
+
+ConstantRange ConstantRange::cttz(bool ZeroIsPoison) const {
+  if (isEmptySet())
+    return getEmpty();
+
+  APInt Zero = APInt::getZero(getBitWidth());
+
+  if (ZeroIsPoison && contains(Zero)) {
+    // ZeroIsPoison is set, and zero is contained. We discern three cases, in
+    // which a zero can appear:
+    // 1) Lower is zero, handling cases of kind [0, 1), [0, 2), etc.
+    // 2) Upper is one, wrapped set ending at zero, e.g., [3, 1).
+    // 3) Zero strictly inside a wrapped set, e.g., [3, 2), [5, 3), etc.
+
+    if (getLower().isZero()) {
+      if ((getUpper() - 1).isZero()) {
+        // We have an input interval of kind [0, 1). In this case we cannot
+        // really help but return the empty set.
+        return getEmpty();
+      }
+
+      // Compute the resulting range by excluding zero from Lower.
+      return getUnsignedCountTrailingZerosRange(getLower() + 1, getUpper());
+    } else if ((getUpper() - 1).isZero()) {
+      // Compute the resulting range by excluding zero from Upper.
+      return ConstantRange(
+          Zero, APInt(getBitWidth(),
+                      (getUnsignedMax() - getLower() + 1).logBase2() + 1));
+    } else {
+      ConstantRange CR1(
+          Zero, APInt(getBitWidth(),
+                      (getUnsignedMax() - getLower() + 1).logBase2() + 1));
+      ConstantRange CR2 = getUnsignedCountTrailingZerosRange(
+          APInt(getBitWidth(), 1), getUpper());
+      return CR1.unionWith(CR2);
+    }
+  }
+
+  if (isFullSet()) {
+    return getNonEmpty(Zero, APInt(getBitWidth(), getBitWidth() + 1));
+  }
+  if (!isUpperWrapped()) {
+    return getUnsignedCountTrailingZerosRange(getLower(), getUpper());
+  }
+  ConstantRange CR1(
+      Zero,
+      APInt(getBitWidth(), (getUnsignedMax() - getLower() + 1).logBase2() + 1));
+  ConstantRange CR2 = getUnsignedCountTrailingZerosRange(Zero, getUpper());
+  return CR1.unionWith(CR2);
+}
+
+static ConstantRange getUnsignedPopCountRange(const APInt &Lower,
+                                              const APInt &Upper) {
+  assert(Lower.ule(Upper));
+  unsigned BitWidth = Lower.getBitWidth();
+  if (Lower == Upper)
+    return ConstantRange::getEmpty(BitWidth);
+  if (Lower + 1 == Upper)
+    return ConstantRange(APInt(BitWidth, Lower.popcount()));
+
+  APInt Max = Upper - 1;
+  // Calculate longest common prefix.
+  unsigned LCPLength = (Lower ^ Max).countl_zero();
+  unsigned LCPPopCount = Lower.getHiBits(LCPLength).popcount();
+  // If Lower is {LCP, 000...}, the minimum is the popcount of LCP.
+  // Otherwise, the minimum is the popcount of LCP + 1.
+  unsigned MinBits =
+      LCPPopCount + (Lower.countr_zero() < BitWidth - LCPLength ? 1 : 0);
+  // If Max is {LCP, 111...}, the maximum is the popcount of LCP + (BitWidth -
+  // length of LCP).
+  // Otherwise, the maximum is the popcount of LCP + (BitWidth -
+  // length of LCP - 1).
+  unsigned MaxBits = LCPPopCount + (BitWidth - LCPLength) +
+                     (Max.countr_one() >= BitWidth - LCPLength ? 1 : 0);
+  return ConstantRange(APInt(BitWidth, MinBits), APInt(BitWidth, MaxBits));
+}
+
+ConstantRange ConstantRange::ctpop() const {
+  if (isEmptySet())
+    return getEmpty();
+
+  unsigned BitWidth = getBitWidth();
+  APInt Zero = APInt::getZero(BitWidth);
+  if (isFullSet()) {
+    return getNonEmpty(Zero, APInt(BitWidth, BitWidth + 1));
+  }
+  if (!isUpperWrapped()) {
+    return getUnsignedPopCountRange(getLower(), getUpper());
+  }
+  ConstantRange CR1 = ConstantRange(
+      APInt(BitWidth,
+            BitWidth - (getUnsignedMax() - getLower() + 1).logBase2()),
+      APInt(BitWidth, BitWidth + 1)); // [lower, intmax]
+  ConstantRange CR2 = getUnsignedPopCountRange(Zero, getUpper()); // [0, upper)
+  return CR1.unionWith(CR2);
+}
 
 ConstantRange::OverflowResult ConstantRange::unsignedAddMayOverflow(
     const ConstantRange &Other) const {
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/range.ll b/llvm/test/Transforms/CorrelatedValuePropagation/range.ll
index 7e89f864c8110ee..182a0bbef255de8 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/range.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/range.ll
@@ -1010,6 +1010,60 @@ else:
   ret i1 %res2
 }
 
+define i1 @cttz_fold(i16 %x) {
+; CHECK-LABEL: @cttz_fold(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i16 [[X:%.*]], 256
+; CHECK-NEXT:    br i1 [[CMP]], label [[IF:%.*]], label [[ELSE:%.*]]
+; CHECK:       if:
+; CHECK-NEXT:    [[CTTZ:%.*]] = call i16 @llvm.cttz.i16(i16 [[X]], i1 true)
+; CHECK-NEXT:    ret i1 false
+; CHECK:       else:
+; CHECK-NEXT:    [[CTTZ2:%.*]] = call i16 @llvm.cttz.i16(i16 [[X]], i1 true)
+; CHECK-NEXT:    [[RES2:%.*]] = icmp ult i16 [[CTTZ2]], 8
+; CHECK-NEXT:    ret i1 [[RES2]]
+;
+  %cmp = icmp ult i16 %x, 256
+  br i1 %cmp, label %if, label %else
+
+if:
+  %cttz = call i16 @llvm.cttz.i16(i16 %x, i1 true)
+  %res = icmp uge i16 %cttz, 8
+  ret i1 %res
+
+else:
+  %cttz2 = call i16 @llvm.cttz.i16(i16 %x, i1 true)
+  %res2 = icmp ult i16 %cttz2, 8
+  ret i1 %res2
+}
+
+define i1 @ctpop_fold(i16 %x) {
+; CHECK-LABEL: @ctpop_fold(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i16 [[X:%.*]], 256
+; CHECK-NEXT:    br i1 [[CMP]], label [[IF:%.*]], label [[ELSE:%.*]]
+; CHECK:       if:
+; CHECK-NEXT:    [[CTPOP:%.*]] = call i16 @llvm.ctpop.i16(i16 [[X]])
+; CHECK-NEXT:    ret i1 true
+; CHECK:       else:
+; CHECK-NEXT:    [[CTPOP2:%.*]] = call i16 @llvm.ctpop.i16(i16 [[X]])
+; CHECK-NEXT:    [[RES2:%.*]] = icmp ugt i16 [[CTPOP2]], 8
+; CHECK-NEXT:    ret i1 [[RES2]]
+;
+  %cmp = icmp ult i16 %x, 256
+  br i1 %cmp, label %if, label %else
+
+if:
+  %ctpop = call i16 @llvm.ctpop.i16(i16 %x)
+  %res = icmp ule i16 %ctpop, 8
+  ret i1 %res
+
+else:
+  %ctpop2 = call i16 @llvm.ctpop.i16(i16 %x)
+  %res2 = icmp ugt i16 %ctpop2, 8
+  ret i1 %res2
+}
+
 declare i16 @llvm.ctlz.i16(i16, i1)
+declare i16 @llvm.cttz.i16(i16, i1)
+declare i16 @llvm.ctpop.i16(i16)
 declare i16 @llvm.abs.i16(i16, i1)
 declare void @llvm.assume(i1)
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
index 1cb358a26062ca5..e505af5d3275ef2 100644
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -2438,6 +2438,26 @@ TEST_F(ConstantRangeTest, Ctlz) {
       });
 }
 
+TEST_F(ConstantRangeTest, Cttz) {
+  TestUnaryOpExhaustive(
+      [](const ConstantRange &CR) { return CR.cttz(); },
+      [](const APInt &N) { return APInt(N.getBitWidth(), N.countr_zero()); });
+
+  TestUnaryOpExhaustive(
+      [](const ConstantRange &CR) { return CR.cttz(/*ZeroIsPoison=*/true); },
+      [](const APInt &N) -> std::optional<APInt> {
+        if (N.isZero())
+          return std::nullopt;
+        return APInt(N.getBitWidth(), N.countr_zero());
+      });
+}
+
+TEST_F(ConstantRangeTest, Ctpop) {
+  TestUnaryOpExhaustive(
+      [](const ConstantRange &CR) { return CR.ctpop(); },
+      [](const APInt &N) { return APInt(N.getBitWidth(), N.popcount()); });
+}
+
 TEST_F(ConstantRangeTest, castOps) {
   ConstantRange A(APInt(16, 66), APInt(16, 128));
   ConstantRange FpToI8 = A.castOp(Instruction::FPToSI, 8);
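
The folds expected by the `@cttz_fold` and `@ctpop_fold` tests in `range.ll` can be sanity-checked by brute force. The sketch below is an independent illustration, not part of the patch. It enumerates 16-bit inputs with plain C++ helpers (all names are hypothetical) and confirms that the `if` branches are constant for every `x < 256`, while the `else` branches take both values.

```cpp
#include <cassert>
#include <cstdint>

// Trailing zero bits of a 16-bit value; 16 for zero, which is what llvm.cttz
// returns when its second argument is false.
unsigned cttz16(uint16_t V) {
  unsigned N = 0;
  while (N < 16 && ((V >> N) & 1u) == 0)
    ++N;
  return N;
}

// Number of set bits in a 16-bit value.
unsigned ctpop16(uint16_t V) {
  unsigned N = 0;
  for (; V != 0; V = static_cast<uint16_t>(V >> 1))
    N += V & 1u;
  return N;
}

// `if` branch of @cttz_fold: x < 256 and zero is poison, so x is in [1, 256).
// `icmp uge cttz, 8` must be false for every such x.
bool cttzIfBranchFolds() {
  for (uint32_t X = 1; X < 256; ++X)
    if (cttz16(static_cast<uint16_t>(X)) >= 8)
      return false;
  return true;
}

// `if` branch of @ctpop_fold: x in [0, 256), `icmp ule ctpop, 8` must be true.
bool ctpopIfBranchFolds() {
  for (uint32_t X = 0; X < 256; ++X)
    if (ctpop16(static_cast<uint16_t>(X)) > 8)
      return false;
  return true;
}

// `else` branches (x >= 256) are not constant: witnesses for both outcomes.
bool elseBranchesStayLive() {
  bool CttzVaries = (cttz16(256) < 8) != (cttz16(257) < 8);
  bool CtpopVaries = (ctpop16(256) > 8) != (ctpop16(0xFFFF) > 8);
  return CttzVaries && CtpopVaries;
}
```

This matches the CHECK lines: the `if` blocks end in `ret i1 false` and `ret i1 true`, and the `else` blocks keep their `icmp`.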

nikic (Contributor) commented Oct 1, 2023

Please split this into separate patches for cttz and ctpop.

dtcxzyw requested a review from nikic as a code owner on October 5, 2023, 13:04
dtcxzyw changed the title from "[ConstantRange] Handle Intrinsic::cttz and Intrinsic::ctpop" to "[ConstantRange] Handle Intrinsic::cttz" on Oct 5, 2023
dtcxzyw (Member, Author) commented Oct 22, 2023

Ping.

goldsteinn (Contributor) commented

LGTM.

dtcxzyw (Member, Author) commented Oct 23, 2023

@nikic Any comments about this PR?

nikic (Contributor) left a comment

I think most of the comments from the ctpop review also apply to this one.

dtcxzyw force-pushed the constrange-cttz-ctpop branch 2 times, most recently from a2ea2e9 to 8860c06, on November 4, 2023, 15:40
github-actions bot commented Nov 4, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

nikic (Contributor) left a comment

LGTM

dtcxzyw merged commit 6f1d3b9 into llvm:main on Nov 6, 2023 (3 checks passed)
dtcxzyw deleted the constrange-cttz-ctpop branch on November 6, 2023, 11:17