Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ValueTracking] improve isKnownNonZero precision for smax #88170

Closed
wants to merge 2 commits into from

Conversation

goldsteinn
Copy link
Contributor

@goldsteinn goldsteinn commented Apr 9, 2024

  • [ValueTracking] Add tests for improving isKnownNonZero of smax; NFC
  • [ValueTracking] improve isKnownNonZero precision for smax

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 9, 2024

@llvm/pr-subscribers-llvm-analysis

Author: None (goldsteinn)

Changes
  • [ValueTracking] Expand isKnown{Negative,Positive} APIs; NFC
  • [ValueTracking] Add tests for improving isKnownNonZero of smax; NFC
  • [ValueTracking] improve isKnownNonZero precision for smax

Full diff: https://github.com/llvm/llvm-project/pull/88170.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Analysis/ValueTracking.h (+10)
  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+47-19)
  • (modified) llvm/test/Transforms/InstSimplify/known-non-zero.ll (+11)
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 3970efba18cc8c..7287a8fb122bbb 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -145,11 +145,21 @@ bool isKnownNonNegative(const Value *V, const SimplifyQuery &SQ,
 bool isKnownPositive(const Value *V, const SimplifyQuery &SQ,
                      unsigned Depth = 0);
 
+/// Returns true if the given value is known be positive (i.e. non-negative
+/// and non-zero) for DemandedElts.
+bool isKnownPositive(const Value *V, const APInt &DemandedElts,
+                     const SimplifyQuery &SQ, unsigned Depth = 0);
+
 /// Returns true if the given value is known be negative (i.e. non-positive
 /// and non-zero).
 bool isKnownNegative(const Value *V, const SimplifyQuery &DL,
                      unsigned Depth = 0);
 
+/// Returns true if the given value is known be negative (i.e. non-positive
+/// and non-zero) for DemandedElts.
+bool isKnownNegative(const Value *V, const APInt &DemandedElts,
+                     const SimplifyQuery &DL, unsigned Depth = 0);
+
 /// Return true if the given values are known to be non-equal when defined.
 /// Supports scalar integer types only.
 bool isKnownNonEqual(const Value *V1, const Value *V2, const DataLayout &DL,
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index ca48cfe7738154..187d781c59e072 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -289,21 +289,52 @@ bool llvm::isKnownNonNegative(const Value *V, const SimplifyQuery &SQ,
   return computeKnownBits(V, Depth, SQ).isNonNegative();
 }
 
-bool llvm::isKnownPositive(const Value *V, const SimplifyQuery &SQ,
-                           unsigned Depth) {
+static bool isKnownPositive(const Value *V, const APInt &DemandedElts,
+                            KnownBits &Known, const SimplifyQuery &SQ,
+                            unsigned Depth) {
   if (auto *CI = dyn_cast<ConstantInt>(V))
     return CI->getValue().isStrictlyPositive();
 
   // If `isKnownNonNegative` ever becomes more sophisticated, make sure to keep
   // this updated.
-  KnownBits Known = computeKnownBits(V, Depth, SQ);
+  Known = computeKnownBits(V, DemandedElts, Depth, SQ);
   return Known.isNonNegative() &&
-         (Known.isNonZero() || ::isKnownNonZero(V, Depth, SQ));
+         (Known.isNonZero() || ::isKnownNonZero(V, DemandedElts, Depth, SQ));
+}
+
+bool llvm::isKnownPositive(const Value *V, const APInt &DemandedElts,
+                           const SimplifyQuery &SQ, unsigned Depth) {
+  KnownBits Known(getBitWidth(V->getType(), SQ.DL));
+  return ::isKnownPositive(V, DemandedElts, Known, SQ, Depth);
+}
+
+bool llvm::isKnownPositive(const Value *V, const SimplifyQuery &SQ,
+                           unsigned Depth) {
+  auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
+  APInt DemandedElts =
+      FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
+  return isKnownPositive(V, DemandedElts, SQ, Depth);
+}
+
+static bool isKnownNegative(const Value *V, const APInt &DemandedElts,
+                            KnownBits &Known, const SimplifyQuery &SQ,
+                            unsigned Depth) {
+  Known = computeKnownBits(V, DemandedElts, Depth, SQ);
+  return Known.isNegative();
+}
+
+bool llvm::isKnownNegative(const Value *V, const APInt &DemandedElts,
+                           const SimplifyQuery &SQ, unsigned Depth) {
+  KnownBits Known(getBitWidth(V->getType(), SQ.DL));
+  return ::isKnownNegative(V, DemandedElts, Known, SQ, Depth);
 }
 
 bool llvm::isKnownNegative(const Value *V, const SimplifyQuery &SQ,
                            unsigned Depth) {
-  return computeKnownBits(V, Depth, SQ).isNegative();
+  auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
+  APInt DemandedElts =
+      FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
+  return isKnownNegative(V, DemandedElts, SQ, Depth);
 }
 
 static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth,
@@ -2830,21 +2861,18 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
                isKnownNonZero(II->getArgOperand(0), DemandedElts, Depth, Q);
       case Intrinsic::smin:
       case Intrinsic::smax: {
-        auto KnownOpImpliesNonZero = [&](const KnownBits &K) {
-          return II->getIntrinsicID() == Intrinsic::smin
-                     ? K.isNegative()
-                     : K.isStrictlyPositive();
+        bool AllNonZero = true;
+        auto KnownOpImpliesNonZero = [&](const Value *Op) {
+          KnownBits TmpKnown(getBitWidth(Op->getType(), Q.DL));
+          bool Ret =
+              II->getIntrinsicID() == Intrinsic::smin
+                  ? ::isKnownNegative(Op, DemandedElts, TmpKnown, Q, Depth)
+                  : ::isKnownPositive(Op, DemandedElts, TmpKnown, Q, Depth);
+          AllNonZero &= TmpKnown.isNonZero();
+          return Ret;
         };
-        KnownBits XKnown =
-            computeKnownBits(II->getArgOperand(0), DemandedElts, Depth, Q);
-        if (KnownOpImpliesNonZero(XKnown))
-          return true;
-        KnownBits YKnown =
-            computeKnownBits(II->getArgOperand(1), DemandedElts, Depth, Q);
-        if (KnownOpImpliesNonZero(YKnown))
-          return true;
-
-        if (XKnown.isNonZero() && YKnown.isNonZero())
+        if (KnownOpImpliesNonZero(II->getArgOperand(0)) ||
+            KnownOpImpliesNonZero(II->getArgOperand(1)) || AllNonZero)
           return true;
       }
         [[fallthrough]];
diff --git a/llvm/test/Transforms/InstSimplify/known-non-zero.ll b/llvm/test/Transforms/InstSimplify/known-non-zero.ll
index b647f11af4461d..51f80f62c2f34c 100644
--- a/llvm/test/Transforms/InstSimplify/known-non-zero.ll
+++ b/llvm/test/Transforms/InstSimplify/known-non-zero.ll
@@ -166,3 +166,14 @@ A:
 B:
   ret i1 0
 }
+
+define i1 @smax_non_zero(i8 %xx, i8 %y) {
+; CHECK-LABEL: @smax_non_zero(
+; CHECK-NEXT:    ret i1 false
+;
+  %x0 = and i8 %xx, 63
+  %x = add i8 %x0, 1
+  %v = call i8 @llvm.smax.i8(i8 %x, i8 %y)
+  %r = icmp eq i8 %v, 0
+  ret i1 %r
+}

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 9, 2024

@llvm/pr-subscribers-llvm-transforms

Author: None (goldsteinn)

Changes
  • [ValueTracking] Expand isKnown{Negative,Positive} APIs; NFC
  • [ValueTracking] Add tests for improving isKnownNonZero of smax; NFC
  • [ValueTracking] improve isKnownNonZero precision for smax

Full diff: https://github.com/llvm/llvm-project/pull/88170.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Analysis/ValueTracking.h (+10)
  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+47-19)
  • (modified) llvm/test/Transforms/InstSimplify/known-non-zero.ll (+11)
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 3970efba18cc8c..7287a8fb122bbb 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -145,11 +145,21 @@ bool isKnownNonNegative(const Value *V, const SimplifyQuery &SQ,
 bool isKnownPositive(const Value *V, const SimplifyQuery &SQ,
                      unsigned Depth = 0);
 
+/// Returns true if the given value is known be positive (i.e. non-negative
+/// and non-zero) for DemandedElts.
+bool isKnownPositive(const Value *V, const APInt &DemandedElts,
+                     const SimplifyQuery &SQ, unsigned Depth = 0);
+
 /// Returns true if the given value is known be negative (i.e. non-positive
 /// and non-zero).
 bool isKnownNegative(const Value *V, const SimplifyQuery &DL,
                      unsigned Depth = 0);
 
+/// Returns true if the given value is known be negative (i.e. non-positive
+/// and non-zero) for DemandedElts.
+bool isKnownNegative(const Value *V, const APInt &DemandedElts,
+                     const SimplifyQuery &DL, unsigned Depth = 0);
+
 /// Return true if the given values are known to be non-equal when defined.
 /// Supports scalar integer types only.
 bool isKnownNonEqual(const Value *V1, const Value *V2, const DataLayout &DL,
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index ca48cfe7738154..187d781c59e072 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -289,21 +289,52 @@ bool llvm::isKnownNonNegative(const Value *V, const SimplifyQuery &SQ,
   return computeKnownBits(V, Depth, SQ).isNonNegative();
 }
 
-bool llvm::isKnownPositive(const Value *V, const SimplifyQuery &SQ,
-                           unsigned Depth) {
+static bool isKnownPositive(const Value *V, const APInt &DemandedElts,
+                            KnownBits &Known, const SimplifyQuery &SQ,
+                            unsigned Depth) {
   if (auto *CI = dyn_cast<ConstantInt>(V))
     return CI->getValue().isStrictlyPositive();
 
   // If `isKnownNonNegative` ever becomes more sophisticated, make sure to keep
   // this updated.
-  KnownBits Known = computeKnownBits(V, Depth, SQ);
+  Known = computeKnownBits(V, DemandedElts, Depth, SQ);
   return Known.isNonNegative() &&
-         (Known.isNonZero() || ::isKnownNonZero(V, Depth, SQ));
+         (Known.isNonZero() || ::isKnownNonZero(V, DemandedElts, Depth, SQ));
+}
+
+bool llvm::isKnownPositive(const Value *V, const APInt &DemandedElts,
+                           const SimplifyQuery &SQ, unsigned Depth) {
+  KnownBits Known(getBitWidth(V->getType(), SQ.DL));
+  return ::isKnownPositive(V, DemandedElts, Known, SQ, Depth);
+}
+
+bool llvm::isKnownPositive(const Value *V, const SimplifyQuery &SQ,
+                           unsigned Depth) {
+  auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
+  APInt DemandedElts =
+      FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
+  return isKnownPositive(V, DemandedElts, SQ, Depth);
+}
+
+static bool isKnownNegative(const Value *V, const APInt &DemandedElts,
+                            KnownBits &Known, const SimplifyQuery &SQ,
+                            unsigned Depth) {
+  Known = computeKnownBits(V, DemandedElts, Depth, SQ);
+  return Known.isNegative();
+}
+
+bool llvm::isKnownNegative(const Value *V, const APInt &DemandedElts,
+                           const SimplifyQuery &SQ, unsigned Depth) {
+  KnownBits Known(getBitWidth(V->getType(), SQ.DL));
+  return ::isKnownNegative(V, DemandedElts, Known, SQ, Depth);
 }
 
 bool llvm::isKnownNegative(const Value *V, const SimplifyQuery &SQ,
                            unsigned Depth) {
-  return computeKnownBits(V, Depth, SQ).isNegative();
+  auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
+  APInt DemandedElts =
+      FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
+  return isKnownNegative(V, DemandedElts, SQ, Depth);
 }
 
 static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth,
@@ -2830,21 +2861,18 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
                isKnownNonZero(II->getArgOperand(0), DemandedElts, Depth, Q);
       case Intrinsic::smin:
       case Intrinsic::smax: {
-        auto KnownOpImpliesNonZero = [&](const KnownBits &K) {
-          return II->getIntrinsicID() == Intrinsic::smin
-                     ? K.isNegative()
-                     : K.isStrictlyPositive();
+        bool AllNonZero = true;
+        auto KnownOpImpliesNonZero = [&](const Value *Op) {
+          KnownBits TmpKnown(getBitWidth(Op->getType(), Q.DL));
+          bool Ret =
+              II->getIntrinsicID() == Intrinsic::smin
+                  ? ::isKnownNegative(Op, DemandedElts, TmpKnown, Q, Depth)
+                  : ::isKnownPositive(Op, DemandedElts, TmpKnown, Q, Depth);
+          AllNonZero &= TmpKnown.isNonZero();
+          return Ret;
         };
-        KnownBits XKnown =
-            computeKnownBits(II->getArgOperand(0), DemandedElts, Depth, Q);
-        if (KnownOpImpliesNonZero(XKnown))
-          return true;
-        KnownBits YKnown =
-            computeKnownBits(II->getArgOperand(1), DemandedElts, Depth, Q);
-        if (KnownOpImpliesNonZero(YKnown))
-          return true;
-
-        if (XKnown.isNonZero() && YKnown.isNonZero())
+        if (KnownOpImpliesNonZero(II->getArgOperand(0)) ||
+            KnownOpImpliesNonZero(II->getArgOperand(1)) || AllNonZero)
           return true;
       }
         [[fallthrough]];
diff --git a/llvm/test/Transforms/InstSimplify/known-non-zero.ll b/llvm/test/Transforms/InstSimplify/known-non-zero.ll
index b647f11af4461d..51f80f62c2f34c 100644
--- a/llvm/test/Transforms/InstSimplify/known-non-zero.ll
+++ b/llvm/test/Transforms/InstSimplify/known-non-zero.ll
@@ -166,3 +166,14 @@ A:
 B:
   ret i1 0
 }
+
+define i1 @smax_non_zero(i8 %xx, i8 %y) {
+; CHECK-LABEL: @smax_non_zero(
+; CHECK-NEXT:    ret i1 false
+;
+  %x0 = and i8 %xx, 63
+  %x = add i8 %x0, 1
+  %v = call i8 @llvm.smax.i8(i8 %x, i8 %y)
+  %r = icmp eq i8 %v, 0
+  ret i1 %r
+}

@goldsteinn goldsteinn changed the title perf/goldsteinn/improve smax [ValueTracking] improve isKnownNonZero precision for smax Apr 9, 2024
@goldsteinn goldsteinn requested a review from dtcxzyw April 9, 2024 18:27
dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Apr 9, 2024
Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is too much magic in these APIs now. isKnownPositive() has a Known out parameter now, but it does not actually get initialized in all cases (if the value is actually a constant).

I think this would end up being cleaner overall if you just spelled out that extra isKnownNonZero call in the smax implementation directly.

llvm/lib/Analysis/ValueTracking.cpp Outdated Show resolved Hide resolved
llvm/lib/Analysis/ValueTracking.cpp Outdated Show resolved Hide resolved
llvm/lib/Analysis/ValueTracking.cpp Outdated Show resolved Hide resolved
Instead of relying on known-bits for strictly positive, use the
`isKnownPositive` API. This will use `isKnownNonZero` which is more
accurate.
Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

// If either arg is strictly positive the result is non-zero. Otherwise
// the result is non-zero if both ops are non-zero.
auto IsNonZero = [&](Value *Op, std::optional<bool> &OpNonZero,
KnownBits OpKnown) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
KnownBits OpKnown) {
const KnownBits &OpKnown) {

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants