
[KnownBits] Make nuw and nsw support in computeForAddSub optimal #83382

Closed
wants to merge 2 commits into from

Conversation

goldsteinn
Contributor

@goldsteinn commented Feb 29, 2024

  • [KnownBits] Add API for nuw flag in computeForAddSub; NFC
  • [KnownBits] Make nuw and nsw support in computeForAddSub optimal

@goldsteinn changed the title goldsteinn/knownbits add sub → [Knowbits] Make nuw and nsw support in computeForAddSub optimal Feb 29, 2024
@llvmbot
Collaborator

llvmbot commented Feb 29, 2024

@llvm/pr-subscribers-llvm-selectiondag
@llvm/pr-subscribers-llvm-support
@llvm/pr-subscribers-backend-amdgpu
@llvm/pr-subscribers-backend-aarch64
@llvm/pr-subscribers-llvm-transforms
@llvm/pr-subscribers-backend-arm

@llvm/pr-subscribers-llvm-globalisel

Author: None (goldsteinn)

Changes
  • [KnownBits] Add API for nuw flag in computeForAddSub; NFC
  • [Knowbits] Make nuw and nsw support in computeForAddSub optimal

Patch is 35.55 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/83382.diff

17 Files Affected:

  • (modified) llvm/include/llvm/Support/KnownBits.h (+2-2)
  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+25-21)
  • (modified) llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp (+5-5)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+3-2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+3-3)
  • (modified) llvm/lib/Support/KnownBits.cpp (+167-16)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp (+2-1)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp (+1-1)
  • (modified) llvm/lib/Target/ARM/ARMISelLowering.cpp (+2-1)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp (+10-4)
  • (modified) llvm/test/CodeGen/AArch64/sve-cmp-folds.ll (+6-3)
  • (modified) llvm/test/CodeGen/AArch64/sve-extract-element.ll (+5-3)
  • (modified) llvm/test/CodeGen/AMDGPU/ds-sub-offset.ll (+17-22)
  • (modified) llvm/test/Transforms/InstCombine/fold-log2-ceil-idiom.ll (+1-1)
  • (modified) llvm/test/Transforms/InstCombine/icmp-sub.ll (+2-3)
  • (modified) llvm/test/Transforms/InstCombine/sub.ll (+1-1)
  • (modified) llvm/unittests/Support/KnownBitsTest.cpp (+54-16)
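For context on what "optimal" means in the title: a KnownBits transfer function is optimal when it computes exactly the bits shared by every result reachable from concrete inputs consistent with the operands' known bits (and, once nuw/nsw are involved, not triggering poison). The following is a minimal brute-force sketch of that definition over 8-bit values; it is a hypothetical stand-in for illustration, not LLVM's APInt-based implementation, and all names in it are made up:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical 8-bit stand-in for llvm::KnownBits: a bit set in Zero means
// the value's bit is known 0; a bit set in One means it is known 1.
struct KnownBits8 {
  uint8_t Zero = 0, One = 0;
};

// A concrete value v matches K if it is 0 on every known-zero bit and 1 on
// every known-one bit.
static bool matches(uint8_t v, KnownBits8 K) {
  return (v & K.Zero) == 0 && (v & K.One) == K.One;
}

// The *optimal* known bits of `add` (optionally nuw): intersect over every
// concrete pair that matches the inputs and, under nuw, does not wrap.
static KnownBits8 exactAddKnownBits(KnownBits8 L, KnownBits8 R, bool NUW) {
  uint8_t alwaysZero = 0xFF, alwaysOne = 0xFF;
  bool anyValidPair = false;
  for (unsigned x = 0; x < 256; ++x) {
    if (!matches(uint8_t(x), L))
      continue;
    for (unsigned y = 0; y < 256; ++y) {
      if (!matches(uint8_t(y), R))
        continue;
      if (NUW && x + y > 255)
        continue; // unsigned wrap: poison under nuw, not a real result
      uint8_t sum = uint8_t(x + y);
      alwaysZero &= uint8_t(~sum); // bits 0 in every result so far
      alwaysOne &= sum;            // bits 1 in every result so far
      anyValidPair = true;
    }
  }
  if (!anyValidPair)
    return {0xFF, 0}; // every pair violates the flag; fall back to all-zero
  return {alwaysZero, alwaysOne};
}
```

For example, if the LHS has its top bit known one and the RHS is fully unknown, nuw pins the top bit of the sum, because any RHS large enough to wrap is excluded; without nuw, nothing is known.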
diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index fb034e0b9e3baf..4e9eb0c10a5628 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -329,8 +329,8 @@ struct KnownBits {
       const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry);
 
   /// Compute known bits resulting from adding LHS and RHS.
-  static KnownBits computeForAddSub(bool Add, bool NSW, const KnownBits &LHS,
-                                    KnownBits RHS);
+  static KnownBits computeForAddSub(bool Add, bool NSW, bool NUW,
+                                    const KnownBits &LHS, KnownBits RHS);
 
   /// Compute known bits results from subtracting RHS from LHS with 1-bit
   /// Borrow.
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index e591ac504e9f05..c220674c5f21d2 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -350,18 +350,19 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
 }
 
 static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
-                                   bool NSW, const APInt &DemandedElts,
+                                   bool NSW, bool NUW,
+                                   const APInt &DemandedElts,
                                    KnownBits &KnownOut, KnownBits &Known2,
                                    unsigned Depth, const SimplifyQuery &Q) {
   computeKnownBits(Op1, DemandedElts, KnownOut, Depth + 1, Q);
 
   // If one operand is unknown and we have no nowrap information,
   // the result will be unknown independently of the second operand.
-  if (KnownOut.isUnknown() && !NSW)
+  if (KnownOut.isUnknown() && !NSW && !NUW)
     return;
 
   computeKnownBits(Op0, DemandedElts, Known2, Depth + 1, Q);
-  KnownOut = KnownBits::computeForAddSub(Add, NSW, Known2, KnownOut);
+  KnownOut = KnownBits::computeForAddSub(Add, NSW, NUW, Known2, KnownOut);
 }
 
 static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
@@ -1145,13 +1146,15 @@ static void computeKnownBitsFromOperator(const Operator *I,
   }
   case Instruction::Sub: {
     bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
-    computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW,
+    bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
+    computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW, NUW,
                            DemandedElts, Known, Known2, Depth, Q);
     break;
   }
   case Instruction::Add: {
     bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
-    computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW,
+    bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
+    computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW, NUW,
                            DemandedElts, Known, Known2, Depth, Q);
     break;
   }
@@ -1245,12 +1248,12 @@ static void computeKnownBitsFromOperator(const Operator *I,
       // Note that inbounds does *not* guarantee nsw for the addition, as only
       // the offset is signed, while the base address is unsigned.
       Known = KnownBits::computeForAddSub(
-          /*Add=*/true, /*NSW=*/false, Known, IndexBits);
+          /*Add=*/true, /*NSW=*/false, /* NUW=*/false, Known, IndexBits);
     }
     if (!Known.isUnknown() && !AccConstIndices.isZero()) {
       KnownBits Index = KnownBits::makeConstant(AccConstIndices);
       Known = KnownBits::computeForAddSub(
-          /*Add=*/true, /*NSW=*/false, Known, Index);
+          /*Add=*/true, /*NSW=*/false, /* NUW=*/false, Known, Index);
     }
     break;
   }
@@ -1689,15 +1692,15 @@ static void computeKnownBitsFromOperator(const Operator *I,
         default: break;
         case Intrinsic::uadd_with_overflow:
         case Intrinsic::sadd_with_overflow:
-          computeKnownBitsAddSub(true, II->getArgOperand(0),
-                                 II->getArgOperand(1), false, DemandedElts,
-                                 Known, Known2, Depth, Q);
+          computeKnownBitsAddSub(
+              true, II->getArgOperand(0), II->getArgOperand(1), /*NSW*/ false,
+              /* NUW*/ false, DemandedElts, Known, Known2, Depth, Q);
           break;
         case Intrinsic::usub_with_overflow:
         case Intrinsic::ssub_with_overflow:
-          computeKnownBitsAddSub(false, II->getArgOperand(0),
-                                 II->getArgOperand(1), false, DemandedElts,
-                                 Known, Known2, Depth, Q);
+          computeKnownBitsAddSub(
+              false, II->getArgOperand(0), II->getArgOperand(1), /*NSW*/ false,
+              /* NUW*/ false, DemandedElts, Known, Known2, Depth, Q);
           break;
         case Intrinsic::umul_with_overflow:
         case Intrinsic::smul_with_overflow:
@@ -2318,7 +2321,11 @@ static bool isNonZeroRecurrence(const PHINode *PN) {
 
 static bool isNonZeroAdd(const APInt &DemandedElts, unsigned Depth,
                          const SimplifyQuery &Q, unsigned BitWidth, Value *X,
-                         Value *Y, bool NSW) {
+                         Value *Y, bool NSW, bool NUW) {
+  if (NUW)
+    return isKnownNonZero(Y, DemandedElts, Depth, Q) ||
+           isKnownNonZero(X, DemandedElts, Depth, Q);
+
   KnownBits XKnown = computeKnownBits(X, DemandedElts, Depth, Q);
   KnownBits YKnown = computeKnownBits(Y, DemandedElts, Depth, Q);
 
@@ -2351,7 +2358,7 @@ static bool isNonZeroAdd(const APInt &DemandedElts, unsigned Depth,
       isKnownToBeAPowerOfTwo(X, /*OrZero*/ false, Depth, Q))
     return true;
 
-  return KnownBits::computeForAddSub(/*Add*/ true, NSW, XKnown, YKnown)
+  return KnownBits::computeForAddSub(/*Add*/ true, NSW, NUW, XKnown, YKnown)
       .isNonZero();
 }
 
@@ -2556,12 +2563,9 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
     // If Add has nuw wrap flag, then if either X or Y is non-zero the result is
     // non-zero.
     auto *BO = cast<OverflowingBinaryOperator>(I);
-    if (Q.IIQ.hasNoUnsignedWrap(BO))
-      return isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q) ||
-             isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q);
-
     return isNonZeroAdd(DemandedElts, Depth, Q, BitWidth, I->getOperand(0),
-                        I->getOperand(1), Q.IIQ.hasNoSignedWrap(BO));
+                        I->getOperand(1), Q.IIQ.hasNoSignedWrap(BO),
+                        Q.IIQ.hasNoUnsignedWrap(BO));
   }
   case Instruction::Mul: {
     // If X and Y are non-zero then so is X * Y as long as the multiplication
@@ -2716,7 +2720,7 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
       case Intrinsic::sadd_sat:
         return isNonZeroAdd(DemandedElts, Depth, Q, BitWidth,
                             II->getArgOperand(0), II->getArgOperand(1),
-                            /*NSW*/ true);
+                            /*NSW*/ true, /* NUW*/ false);
       case Intrinsic::umax:
       case Intrinsic::uadd_sat:
         return isKnownNonZero(II->getArgOperand(1), DemandedElts, Depth, Q) ||
diff --git a/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp b/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
index ea8c20cdcd45d6..83c04612d2d43e 100644
--- a/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
@@ -269,8 +269,8 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
                          Depth + 1);
     computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
                          Depth + 1);
-    Known = KnownBits::computeForAddSub(/*Add*/ false, /*NSW*/ false, Known,
-                                        Known2);
+    Known = KnownBits::computeForAddSub(/*Add*/ false, /*NSW*/ false,
+                                        /* NUW*/ false, Known, Known2);
     break;
   }
   case TargetOpcode::G_XOR: {
@@ -296,8 +296,8 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
                          Depth + 1);
     computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
                          Depth + 1);
-    Known =
-        KnownBits::computeForAddSub(/*Add*/ true, /*NSW*/ false, Known, Known2);
+    Known = KnownBits::computeForAddSub(/*Add*/ true, /*NSW*/ false,
+                                        /* NUW*/ false, Known, Known2);
     break;
   }
   case TargetOpcode::G_AND: {
@@ -564,7 +564,7 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
     // right.
     KnownBits ExtKnown = KnownBits::makeConstant(APInt(BitWidth, BitWidth));
     KnownBits ShiftKnown = KnownBits::computeForAddSub(
-        /*Add*/ false, /*NSW*/ false, ExtKnown, WidthKnown);
+        /*Add*/ false, /*NSW*/ false, /* NUW*/ false, ExtKnown, WidthKnown);
     Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown);
     break;
   }
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index e150f27240d7f0..dbcdb722b741a7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3753,8 +3753,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     SDNodeFlags Flags = Op.getNode()->getFlags();
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-    Known = KnownBits::computeForAddSub(Op.getOpcode() == ISD::ADD,
-                                        Flags.hasNoSignedWrap(), Known, Known2);
+    Known = KnownBits::computeForAddSub(
+        Op.getOpcode() == ISD::ADD, Flags.hasNoSignedWrap(),
+        Flags.hasNoUnsignedWrap(), Known, Known2);
     break;
   }
   case ISD::USUBO:
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 6970b230837fb9..a639cba5e35a80 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -2876,9 +2876,9 @@ bool TargetLowering::SimplifyDemandedBits(
     if (Op.getOpcode() == ISD::MUL) {
       Known = KnownBits::mul(KnownOp0, KnownOp1);
     } else { // Op.getOpcode() is either ISD::ADD or ISD::SUB.
-      Known = KnownBits::computeForAddSub(Op.getOpcode() == ISD::ADD,
-                                          Flags.hasNoSignedWrap(), KnownOp0,
-                                          KnownOp1);
+      Known = KnownBits::computeForAddSub(
+          Op.getOpcode() == ISD::ADD, Flags.hasNoSignedWrap(),
+          Flags.hasNoUnsignedWrap(), KnownOp0, KnownOp1);
     }
     break;
   }
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 770e4051ca3ffa..b575f97094891f 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -54,7 +54,7 @@ KnownBits KnownBits::computeForAddCarry(
       LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue());
 }
 
-KnownBits KnownBits::computeForAddSub(bool Add, bool NSW,
+KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool NUW,
                                       const KnownBits &LHS, KnownBits RHS) {
   KnownBits KnownOut;
   if (Add) {
@@ -63,23 +63,173 @@ KnownBits KnownBits::computeForAddSub(bool Add, bool NSW,
         LHS, RHS, /*CarryZero*/true, /*CarryOne*/false);
   } else {
     // Sum = LHS + ~RHS + 1
-    std::swap(RHS.Zero, RHS.One);
-    KnownOut = ::computeForAddCarry(
-        LHS, RHS, /*CarryZero*/false, /*CarryOne*/true);
+    KnownBits NotRHS = RHS;
+    std::swap(NotRHS.Zero, NotRHS.One);
+    KnownOut = ::computeForAddCarry(LHS, NotRHS, /*CarryZero*/ false,
+                                    /*CarryOne*/ true);
+  }
+  if (!NSW && !NUW)
+    return KnownOut;
+
+  // We truncate out the signbit during nsw handling so just handle this special
+  // case to avoid dealing with it later.
+  if (LHS.getBitWidth() == 1) {
+    return LHS | RHS;
   }
 
-  // Are we still trying to solve for the sign bit?
-  if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) {
+  auto GetMinMaxVal = [Add](bool ForNSW, bool ForMax, const KnownBits &L,
+                            const KnownBits &R, bool &OV) {
+    APInt LVal = ForMax ? L.getMaxValue() : L.getMinValue();
+    APInt RVal = Add == ForMax ? R.getMaxValue() : R.getMinValue();
+
+    if (ForNSW) {
+      LVal = LVal.trunc(LVal.getBitWidth() - 1);
+      RVal = RVal.trunc(RVal.getBitWidth() - 1);
+    }
+    APInt Res = Add ? LVal.uadd_ov(RVal, OV) : LVal.usub_ov(RVal, OV);
+    if (ForNSW)
+      Res = Res.sext(Res.getBitWidth() + 1);
+    return Res;
+  };
+
+  auto GetMaxVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
+                                   const KnownBits &R, bool &OV) {
+    return GetMinMaxVal(ForNSW, /*ForMax*/ true, L, R, OV);
+  };
+
+  auto GetMinVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
+                                   const KnownBits &R, bool &OV) {
+    return GetMinMaxVal(ForNSW, /*ForMax*/ false, L, R, OV);
+  };
+
+  std::optional<bool> Negative;
+  bool Poison = false;
+  // Handle add/sub given nsw and/or nuw.
+  //
+  // Possible TODO: Add/Sub implementations mirror one another in many ways.
+  // They could probably be compressed into a single implementation of roughly
+  // half the total LOC. Leaving seperate for now to increase clarity.
+  // NB: We handle NSW by truncating sign bits then deducing bits based on
+  // the known sign result.
+  if (Add) {
     if (NSW) {
-      // Adding two non-negative numbers, or subtracting a negative number from
-      // a non-negative one, can't wrap into negative.
-      if (LHS.isNonNegative() && RHS.isNonNegative())
-        KnownOut.makeNonNegative();
-      // Adding two negative numbers, or subtracting a non-negative number from
-      // a negative one, can't wrap into non-negative.
-      else if (LHS.isNegative() && RHS.isNegative())
-        KnownOut.makeNegative();
+      bool OverflowMax, OverflowMin;
+      APInt MaxVal = GetMaxVal(/*ForNSW*/ true, LHS, RHS, OverflowMax);
+      APInt MinVal = GetMinVal(/*ForNSW*/ true, LHS, RHS, OverflowMin);
+
+      if (NUW || (LHS.isNonNegative() && RHS.isNonNegative())) {
+        // (add nuw) or (add nsw PosX, PosY)
+
+        // None of the adds can end up overflowing, so min consecutive highbits
+        // in minimum possible of X + Y must all remain set.
+        KnownOut.One.setHighBits(MinVal.countLeadingOnes());
+
+        // NSW and Positive arguments leads to positive result.
+        if (LHS.isNonNegative() && RHS.isNonNegative())
+          Negative = false;
+        else
+          KnownOut.One.clearSignBit();
+
+        Poison = OverflowMin;
+      } else if (LHS.isNegative() && RHS.isNegative()) {
+        // (add nsw NegX, NegY)
+
+        // We need to re-overflow the signbit, so we are looking for sequence of
+        // 0s from consecutive overflows.
+        KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
+        Negative = true;
+        Poison = !OverflowMax;
+      } else if (LHS.isNonNegative() || RHS.isNonNegative()) {
+        // (add nsw PosX, ?Y)
+
+        // If the minimal possible of X + Y overflows the signbit, then Y must
+        // have been signed (which will cause unsigned overflow otherwise nsw
+        // will be violated) leading to unsigned result.
+        if (OverflowMin)
+          Negative = false;
+      } else if (LHS.isNegative() || RHS.isNegative()) {
+        // (add nsw NegX, ?Y)
+
+        // If the maximum possible of X + Y doesn't overflows the signbit, then
+        // Y must have been unsigned (otherwise nsw violated) so NegX + PosY w.o
+        // overflowing the signbit results in Negative.
+        if (!OverflowMax)
+          Negative = true;
+      }
     }
+    if (NUW) {
+        // (add nuw X, Y)
+      bool OverflowMax, OverflowMin;
+      APInt MaxVal = GetMaxVal(/*ForNSW*/ false, LHS, RHS, OverflowMax);
+      APInt MinVal = GetMinVal(/*ForNSW*/ false, LHS, RHS, OverflowMin);
+      // Same as (add nsw PosX, PosY), basically since we can't overflow, the
+      // high bits of minimum possible X + Y must remain set.
+      KnownOut.One.setHighBits(MinVal.countLeadingOnes());
+      Poison = OverflowMin;
+    }
+  } else {
+    if (NSW) {
+      bool OverflowMax, OverflowMin;
+      APInt MaxVal = GetMaxVal(/*ForNSW*/ true, LHS, RHS, OverflowMax);
+      APInt MinVal = GetMinVal(/*ForNSW*/ true, LHS, RHS, OverflowMin);
+      if (NUW || (LHS.isNegative() && RHS.isNonNegative())) {
+        // (sub nuw) or (sub nsw NegX, PosY)
+
+        // None of the subs can overflow at any point, so any common high bits
+        // will subtract away and result in zeros.
+        KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
+        if (LHS.isNegative() && RHS.isNonNegative())
+          Negative = true;
+        else
+          KnownOut.Zero.clearSignBit();
+
+        Poison = OverflowMax;
+      } else if (LHS.isNonNegative() && RHS.isNegative()) {
+        // (sub nsw PosX, NegY)
+        Negative = false;
+
+        // Opposite case of above, we must "re-overflow" the signbit, so minimal
+        // set of high bits will be fixed.
+        KnownOut.One.setHighBits(MinVal.countLeadingOnes());
+        Poison = !OverflowMin;
+      } else if (LHS.isNegative() || RHS.isNonNegative()) {
+        // (sub nsw NegX/?X, ?Y/PosY)
+        if (OverflowMax)
+          Negative = true;
+      } else if (LHS.isNonNegative() || RHS.isNegative()) {
+        // (sub nsw PosX/?X, ?Y/NegY)
+        if (!OverflowMin)
+          Negative = false;
+      }
+    }
+    if (NUW) {
+      // (sub nuw X, Y)
+      bool OverflowMax, OverflowMin;
+      APInt MaxVal = GetMaxVal(/*ForNSW*/ false, LHS, RHS, OverflowMax);
+      APInt MinVal = GetMinVal(/*ForNSW*/ false, LHS, RHS, OverflowMin);
+
+      // Basically all common high bits between X/Y will cancel out as leading
+      // zeros.
+      KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
+      Poison = OverflowMax;
+    }
+  }
+
+  // Handle any proven sign bit.
+  if (Negative.has_value()) {
+    KnownOut.One.clearSignBit();
+    KnownOut.Zero.clearSignBit();
+
+    if (*Negative)
+      KnownOut.makeNegative();
+    else
+      KnownOut.makeNonNegative();
+  }
+
+  // Just return 0 if the nsw/nuw is violated and we have poison.
+  if (Poison || KnownOut.hasConflict()) {
+    KnownOut.setAllZero();
+    return KnownOut;
   }
 
   return KnownOut;
@@ -443,7 +593,7 @@ KnownBits KnownBits::abs(bool IntMinIsPoison) const {
       Tmp.One.setBit(countMinTrailingZeros());
 
     KnownAbs = computeForAddSub(
-        /*Add*/ false, IntMinIsPoison,
+        /*Add*/ false, IntMinIsPoison, /*NUW*/ false,
         KnownBits::makeConstant(APInt(getBitWidth(), 0)), Tmp);
 
     // One more special case for IntMinIsPoison. If we don't know any ones other
@@ -489,7 +639,8 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
   assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs");
   // We don't see NSW even for sadd/ssub as we want to check if the result has
   // signed overflow.
-  KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW*/ false, LHS, RHS);
+  KnownBits Res =
+      KnownBits::computeForAddSub(Add, /*NSW*/ false, /*NUW*/ false, LHS, RHS);
   unsigned BitWidth = Res.getBitWidth();
   auto SignBitKnown = [&](const KnownBits &K) {
     return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
index 4896ae8bad9ef3..1e7cd2bab04123 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
@@ -1903,7 +1903,8 @@ bool AMDGPUDAGToDAGISel::checkFlatScratchSVSSwizzleBug(
   // voffset to (soffset + inst_offset).
   KnownBits VKnown = CurDAG->computeKnownBits(VAddr);
   KnownBits SKnown = KnownBits::computeForAddSub(
-      true, false, CurDAG->computeKnownBits(SAddr),
+      /*Add*/ true, /*NSW*/ false, /*NUW*/ false,
+      CurDAG->computeKnownBits(SAddr),
       KnownBits::makeConstant(APInt(32, ImmOffset)));
   uint64_t VMax = VKnown.getMaxValue().getZExtValue();
   uint64_t SMax = SKnown.getMaxValue().getZExtValue();
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUIn...
[truncated]
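One detail worth calling out from the KnownBits.cpp hunk above: for add nuw, the patch sets high bits via KnownOut.One.setHighBits(MinVal.countLeadingOnes()). The soundness argument is that nuw makes every real sum at least the minimum possible sum, without wrapping, and any value greater than or equal to that minimum keeps the minimum's leading ones. A small exhaustive check of that invariant at 8 bits (the helper names are hypothetical, not LLVM's):

```cpp
#include <cassert>
#include <cstdint>

// Stand-in for APInt::countl_one() on an 8-bit value.
static unsigned countLeadingOnes8(uint8_t v) {
  unsigned n = 0;
  for (int bit = 7; bit >= 0 && ((v >> bit) & 1) != 0; --bit)
    ++n;
  return n;
}

// Invariant behind "setHighBits(MinVal.countLeadingOnes())" for add nuw:
// if the minimum possible sum minSum has its top k bits set, then every
// non-wrapping sum s (minSum <= s <= 255) keeps those top k bits set.
static bool leadingOnesPreserved(uint8_t minSum) {
  unsigned k = countLeadingOnes8(minSum);
  // Mask of the top k bits; k == 0 yields an empty mask.
  uint8_t topK = uint8_t(k == 0 ? 0 : 0xFF << (8 - k));
  for (unsigned s = minSum; s <= 255; ++s)
    if ((uint8_t(s) & topK) != topK)
      return false;
  return true;
}
```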


github-actions bot commented Feb 29, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@dtcxzyw changed the title [Knowbits] Make nuw and nsw support in computeForAddSub optimal → [KnownBits] Make nuw and nsw support in computeForAddSub optimal Feb 29, 2024
dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Feb 29, 2024
EXPECT_TRUE(KnownNSWComputed.Zero.isSubsetOf(KnownNSW.Zero));
EXPECT_TRUE(KnownNSWComputed.One.isSubsetOf(KnownNSW.One));
IsAdd, /*NSW*/ true, /*NUW*/ false, Known1, Known2);
if (!KnownNSW.hasConflict()) {
Collaborator

Why do you need to add conflict tests? And why not keep the isSubsetOf checks?

Contributor Author

Why do you need to add conflict tests?

Because there are some inputs that will always violate nsw/nuw and yield poison, e.g.:

Value of: isOptimal(KnownNSW, KnownNSWComputed, {Known1, Known2})
  Actual: false (Inputs = 1???00, 100001, Computed = 000000, Exact = !!!!!!)
Expected: true

We could also return a conflict in the poison cases, but because we assert(!Known.hasConflict()) in a lot of places, that could potentially lead to crashes.

And why not keep the isSubsetOf checks?

isOptimal/isCorrect are helpers for the correctness check and have the added benefit of printing the input/output on failure to make debugging easier.
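The conflict case being discussed can be reproduced by hand. For the quoted failure (inputs 1???00 and 100001 at 6 bits with add nsw), both operands are negative and even their largest possible sum is below the 6-bit signed minimum, so every concrete pair overflows and the nsw result is always poison. A sketch verifying that exhaustively (not LLVM code; the helper name is made up):

```cpp
#include <cassert>
#include <cstdint>

// Checks whether any 6-bit pair consistent with the known bits from the
// failure message (LHS = 1???00, RHS = 100001) can be added without
// signed overflow. If not, `add nsw` on these inputs is always poison.
static bool anyNswSurvivingPair() {
  const int Width = 6;
  const int SMin = -(1 << (Width - 1));    // -32
  const int SMax = (1 << (Width - 1)) - 1; // 31
  const unsigned RHS = 0b100001;           // the constant operand
  for (unsigned x = 0; x < (1u << Width); ++x) {
    // LHS pattern 1???00: bit 5 known one, bits 1..0 known zero.
    if ((x & 0b100000) == 0 || (x & 0b000011) != 0)
      continue;
    // Interpret both 6-bit patterns as signed values.
    int sx = int(x) - ((x >> (Width - 1)) ? (1 << Width) : 0);
    int sy = int(RHS) - (1 << Width); // RHS has its sign bit set
    int sum = sx + sy;
    if (sum >= SMin && sum <= SMax)
      return true; // this pair would not violate nsw
  }
  return false;
}
```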

; CHECK-NEXT: fcmeq p1.s, p0/z, z0.s, z1.s
; CHECK-NEXT: ptest p0, p1.b
; CHECK-NEXT: cset w0, lo
; CHECK-NEXT: mov x8, #-1 // =0xffffffffffffffff
Contributor

I don't speak AArch64, but this looks like an unfortunate regression, at least in terms of number of instructions.

Contributor Author

Agreed, ping @sdesmalen-arm, just to make you guys aware.

Collaborator

For the following IR:

define i64 @foo_last() {
  %vscale = call i64 @llvm.vscale.i64()
  %shl2 = shl nuw nsw i64 %vscale, 2
  %idx = add nuw nsw i64 %shl2, -1
  ret i64 %idx
}

declare i64 @llvm.vscale.i64()

With the current patch this is incorrectly simplified to:

t9: ch,glue = CopyToReg t0, Register:i64 $x0, Constant:i64<-1>

I don't think there's anything AArch64-specific going on there.

Contributor

That simplification looks correct to me. The only way add -1 can satisfy nuw is if the other operand is zero, so the result is -1. This probably wasn't supposed to have a nuw flag?
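This reasoning can be checked exhaustively at a small width. Adding the all-ones value (-1) wraps unsigned arithmetic for every operand except zero, so under nuw the only surviving result is -1 itself. A hypothetical 8-bit check (not LLVM code):

```cpp
#include <cassert>
#include <cstdint>

// Does `x + 0xFF` avoid unsigned wrap at 8 bits? Under nuw, pairs that
// wrap are poison, so only x == 0 produces a real result, and that
// result is 0xFF (i.e. -1), matching the fold being discussed.
static bool addAllOnesNoWrap(uint8_t x) {
  return unsigned(x) + 0xFFu <= 0xFFu;
}
```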

Collaborator

Doh :) Yes, I see it now. It should have been sub nuw nsw i64 %shl2, 1 to avoid any wrapping (or indeed remove the nuw flag). I'll have a look to see if we generate this pattern with the nuw flag anywhere. Thanks!

Collaborator

vscale can never be zero though (it should have a minimum value of 1). Should it be returning poison in that case?

Contributor

@jayfoad left a comment

LGTM overall. I have ideas for simplifying it but that can wait.

@nikic
Contributor

nikic commented Mar 1, 2024

@goldsteinn
Contributor Author

This seems to add a good bit of compile-time overhead: http://llvm-compile-time-tracker.com/compare.php?from=2a67c28abe8cfde47c5058abbeb4b5ff9a393192&to=e383a7e50bf0b4303931ef119c856dcf53f3dd9e&stat=instructions%3Au

Wow, is the implementation just that inefficient, or do you think it's changing control flow decisions in some key places?

@nikic
Contributor

nikic commented Mar 1, 2024

This seems to add a good bit of compile-time overhead: http://llvm-compile-time-tracker.com/compare.php?from=2a67c28abe8cfde47c5058abbeb4b5ff9a393192&to=e383a7e50bf0b4303931ef119c856dcf53f3dd9e&stat=instructions%3Au

Wow, is the implementation just that inefficient, or do you think it's changing control flow decisions in some key places?

Pretty sure this is either the implementation being slow, or the change at https://github.com/llvm/llvm-project/pull/83382/files#diff-4cc32f30c79b4c8161eac82916c70c1d56e75b0dd7d6e56bbb76f8b16e20b32dR361 (which might result in more KnownBits calculations now -- dunno whether that case is common or not).

@goldsteinn
Contributor Author

This seems to add a good bit of compile-time overhead: http://llvm-compile-time-tracker.com/compare.php?from=2a67c28abe8cfde47c5058abbeb4b5ff9a393192&to=e383a7e50bf0b4303931ef119c856dcf53f3dd9e&stat=instructions%3Au

Wow, is the implementation just that inefficient, or do you think it's changing control flow decisions in some key places?

Pretty sure this is either the implementation being slow, or the change at https://github.com/llvm/llvm-project/pull/83382/files#diff-4cc32f30c79b4c8161eac82916c70c1d56e75b0dd7d6e56bbb76f8b16e20b32dR361 (which might result in more KnownBits calculations now -- dunno whether that case is common or not).

Okay, I'll investigate and see if I can make the impact a bit more reasonable.

@goldsteinn force-pushed the goldsteinn/knownbits-add-sub branch from 837ca86 to 235f1f4 on March 1, 2024 at 18:25
@goldsteinn
Contributor Author

Value *Y, bool NSW, bool NUW) {
if (NUW)
return isKnownNonZero(Y, DemandedElts, Depth, Q) ||
isKnownNonZero(X, DemandedElts, Depth, Q);
Collaborator

Separate refactor patch?

Contributor Author

This is pretty NFC. The call sites were all checking this manually beforehand, and since we want NUW in isNonZeroAdd it seemed to just make sense. But no strong feelings.

@goldsteinn force-pushed the goldsteinn/knownbits-add-sub branch from 04e24f0 to 99422c7 on March 2, 2024 at 18:22
Contributor

@jayfoad left a comment

LGTM.

Here's an idea for simplifying the code. I don't know if it's measurably faster. I don't really have time to finish it off at the moment.
jayfoad@c854484

@goldsteinn
Contributor Author

LGTM.

Here's an idea for simplifying the code. I don't know if it's measurably faster. I don't really have time to finish it off at the moment. jayfoad@c854484

Thanks, going to push this; I'll look into your impl later.

@goldsteinn
Contributor Author

LGTM.
Here's an idea for simplifying the code. I don't know if it's measurably faster. I don't really have time to finish it off at the moment. jayfoad@c854484

Thanks, going to push this; I'll look into your impl later.

Your code is actually so much better, updating and reposting...

@goldsteinn goldsteinn force-pushed the goldsteinn/knownbits-add-sub branch from 99422c7 to 3da7dac Compare March 4, 2024 19:53
@goldsteinn goldsteinn requested a review from jayfoad March 4, 2024 19:54
dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Mar 4, 2024
@goldsteinn
Contributor Author

LGTM.

Here's an idea for simplifying the code. I don't know if it's measurably faster. I don't really have time to finish it off at the moment. jayfoad@c854484

It is measurably faster :)
https://llvm-compile-time-tracker.com/compare.php?from=1f89a3cde2b407dc77e33ac3092133415ab151a9&to=9d52c0ffea77715ae70f1c66619504230d2da493&stat=instructions:u

Contributor

@jayfoad jayfoad left a comment

LGTM

unsigned NumBits = MinVal.trunc(BitWidth - 1).countl_one();
// If we have NSW as well, we also no we can't overflow the signbit so
// can start counting from 1 bit back.
KnownOut.One.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
Contributor

This was the part of my refactoring that I was least happy about: for the nsw+nuw case we run this line AND line 94 AND all the nsw logic below, and they all seem to be necessary. I'm sure there's a simpler way of handling the nsw+nuw case but I couldn't find it.

@jayfoad
Contributor

jayfoad commented Mar 5, 2024

Here's an idea for simplifying the code. I don't know if it's measurably faster. I don't really have time to finish it off at the moment. jayfoad@c854484

It is measurably faster :) https://llvm-compile-time-tracker.com/compare.php?from=1f89a3cde2b407dc77e33ac3092133415ab151a9&to=9d52c0ffea77715ae70f1c66619504230d2da493&stat=instructions:u

Nice! How does that compare to main?

Member

@dtcxzyw dtcxzyw left a comment

LGTM. But please file an issue to track the SVE regression.

// in minimum possible of X + Y must all remain set.
if (NSW) {
unsigned NumBits = MinVal.trunc(BitWidth - 1).countl_one();
// If we have NSW as well, we also no we can't overflow the signbit so
Contributor

Suggested change
// If we have NSW as well, we also no we can't overflow the signbit so
// If we have NSW as well, we also know we can't overflow the signbit so

// None of the subs can overflow at any point, so any common high bits
// will subtract away and result in zeros.
if (NSW) {
// If we have NSW as well, we also no we can't overflow the signbit so
Contributor

Suggested change
// If we have NSW as well, we also no we can't overflow the signbit so
// If we have NSW as well, we also know we can't overflow the signbit so

// can start counting from 1 bit back.
KnownOut.One.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
}
KnownOut.One.setHighBits(MinVal.countLeadingOnes());
Contributor

There's an odd mix of countl_one and countLeadingOnes directly next to each other here.

EXPECT_TRUE(KnownNSWComputed.Zero.isSubsetOf(KnownNSW.Zero));
EXPECT_TRUE(KnownNSWComputed.One.isSubsetOf(KnownNSW.One));
IsAdd, /*NSW=*/true, /*NUW=*/false, Known1, Known2);
if (!KnownNSW.hasConflict())
Collaborator

Can this still happen? It looks like computeForAddSub now resets on conflict

Contributor

Yes, KnownNSW.hasConflict() will be true in the case where all admissible LHS/RHS values violated the nsw constraint. You're right that in this case, ComputedNSW will have been reset.

@goldsteinn
Contributor Author

Here's an idea for simplifying the code. I don't know if it's measurably faster. I don't really have time to finish it off at the moment. jayfoad@c854484

It is measurably faster :) https://llvm-compile-time-tracker.com/compare.php?from=1f89a3cde2b407dc77e33ac3092133415ab151a9&to=9d52c0ffea77715ae70f1c66619504230d2da493&stat=instructions:u

Nice! How does that compare to main?

Improvement:
https://llvm-compile-time-tracker.com/compare.php?from=a4951eca40c070e020aa5d2689c08177fbeb780d&to=9d52c0ffea77715ae70f1c66619504230d2da493&stat=instructions%3Au

I bet it has to do with the isUnknown checks.

Collaborator

@RKSimon RKSimon left a comment

No more comments, cheers - LGTM

Just some improvements that should hopefully strengthen analysis.

Closes llvm#83580
@goldsteinn
Contributor Author

LGTM. But please file an issue to track the SVE regression.

Done, see: #84046

@goldsteinn
Contributor Author

Closed with: 17162b6 (messed up closed tag).

@goldsteinn goldsteinn closed this Mar 5, 2024
if (NUW) {
if (Add) {
// (add nuw X, Y)
APInt MinVal = LHS.getMinValue().uadd_sat(RHS.getMinValue());
Contributor

From testing, it seems that this can be a normal (non-saturating) + and the usub_sat below can be -. But in the NSW code, the sadd_sat/ssub_sat are required. I don't understand why.

Contributor Author

I think it's because this is minval; does replacing the sat in the NSW path for minval only work?

Contributor Author

Oh, but it's signed min. You can replace MaxVal = LHS.getSignedMaxValue() + RHS.getSignedMaxValue(); below.
