Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[KnownBits] Implement knownbits lshr/ashr with exact flag #84254

Closed
wants to merge 3 commits into from

Conversation

goldsteinn
Copy link
Contributor

@goldsteinn goldsteinn commented Mar 6, 2024

  • [KnownBits] Add test for computing more information for lshr/ashr with exact flag; NFC
  • [KnownBits] Add API support for exact in lshr/ashr; NFC
  • [KnownBits] Implement knownbits lshr/ashr with exact flag

@llvmbot
Copy link
Collaborator

llvmbot commented Mar 6, 2024

@llvm/pr-subscribers-llvm-selectiondag
@llvm/pr-subscribers-llvm-analysis
@llvm/pr-subscribers-llvm-support

@llvm/pr-subscribers-llvm-globalisel

Author: None (goldsteinn)

Changes
  • [KnownBits] Add API support for exact in lshr/ashr; NFC
  • [KnownBits] Implement knownbits lshr/ashr with exact flag

Full diff: https://github.com/llvm/llvm-project/pull/84254.diff

6 Files Affected:

  • (modified) llvm/include/llvm/Support/KnownBits.h (+2-2)
  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+8-6)
  • (modified) llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp (+3-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+4-2)
  • (modified) llvm/lib/Support/KnownBits.cpp (+26-2)
  • (modified) llvm/unittests/Support/KnownBitsTest.cpp (+40-1)
diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index 46dbf0c2baa5fe..06d2c90f7b0f6b 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -402,12 +402,12 @@ struct KnownBits {
   /// Compute known bits for lshr(LHS, RHS).
   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
   static KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS,
-                        bool ShAmtNonZero = false);
+                        bool ShAmtNonZero = false, bool Exact = false);
 
   /// Compute known bits for ashr(LHS, RHS).
   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
   static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS,
-                        bool ShAmtNonZero = false);
+                        bool ShAmtNonZero = false, bool Exact = false);
 
   /// Determine if these known bits always give the same ICMP_EQ result.
   static std::optional<bool> eq(const KnownBits &LHS, const KnownBits &RHS);
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 52ae9f034e5d34..3304db68e3deae 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -1142,9 +1142,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
     break;
   }
   case Instruction::LShr: {
-    auto KF = [](const KnownBits &KnownVal, const KnownBits &KnownAmt,
-                 bool ShAmtNonZero) {
-      return KnownBits::lshr(KnownVal, KnownAmt, ShAmtNonZero);
+    bool Exact = Q.IIQ.isExact(cast<BinaryOperator>(I));
+    auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
+                      bool ShAmtNonZero) {
+      return KnownBits::lshr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
     };
     computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
                                       KF);
@@ -1155,9 +1156,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
     break;
   }
   case Instruction::AShr: {
-    auto KF = [](const KnownBits &KnownVal, const KnownBits &KnownAmt,
-                 bool ShAmtNonZero) {
-      return KnownBits::ashr(KnownVal, KnownAmt, ShAmtNonZero);
+    bool Exact = Q.IIQ.isExact(cast<BinaryOperator>(I));
+    auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
+                      bool ShAmtNonZero) {
+      return KnownBits::ashr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
     };
     computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
                                       KF);
diff --git a/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp b/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
index 099bf45b2734cb..21e0b7b2b68fc7 100644
--- a/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
@@ -565,7 +565,9 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
     KnownBits ExtKnown = KnownBits::makeConstant(APInt(BitWidth, BitWidth));
     KnownBits ShiftKnown = KnownBits::computeForAddSub(
         /*Add=*/false, /*NSW=*/false, /* NUW=*/false, ExtKnown, WidthKnown);
-    Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown);
+    Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown,
+                            /*ShAmtNonZero=*/false,
+                            /*Exact*/ true);
     break;
   }
   case TargetOpcode::G_UADDO:
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index f7ace79e8c51d4..92e20cc1304b70 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3485,7 +3485,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
   case ISD::SRL:
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-    Known = KnownBits::lshr(Known, Known2);
+    Known = KnownBits::lshr(Known, Known2, /*ShAmtNonZero=*/false,
+                            Op->getFlags().hasExact());
 
     // Minimum shift high bits are known zero.
     if (const APInt *ShMinAmt =
@@ -3495,7 +3496,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
   case ISD::SRA:
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-    Known = KnownBits::ashr(Known, Known2);
+    Known = KnownBits::ashr(Known, Known2, /*ShAmtNonZero=*/false,
+                            Op->getFlags().hasExact());
     break;
   case ISD::FSHL:
   case ISD::FSHR:
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 74d857457aec1e..c33c3680825a10 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -343,7 +343,7 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
 }
 
 KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
-                          bool ShAmtNonZero) {
+                          bool ShAmtNonZero, bool Exact) {
   unsigned BitWidth = LHS.getBitWidth();
   auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
     KnownBits Known = LHS;
@@ -367,6 +367,18 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
   // Find the common bits from all possible shifts.
   APInt MaxValue = RHS.getMaxValue();
   unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+
+  // If exact, bound MaxShiftAmount to first known 1 in LHS.
+  if (Exact) {
+    unsigned FirstOne = LHS.countMaxTrailingZeros();
+    if (FirstOne < MinShiftAmount) {
+      // Always poison. Return zero because we don't like returning conflict.
+      Known.setAllZero();
+      return Known;
+    }
+    MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
+  }
+
   unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
   unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
   Known.Zero.setAllBits();
@@ -389,7 +401,7 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
 }
 
 KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
-                          bool ShAmtNonZero) {
+                          bool ShAmtNonZero, bool Exact) {
   unsigned BitWidth = LHS.getBitWidth();
   auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
     KnownBits Known = LHS;
@@ -415,6 +427,18 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
   // Find the common bits from all possible shifts.
   APInt MaxValue = RHS.getMaxValue();
   unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+
+  // If exact, bound MaxShiftAmount to first known 1 in LHS.
+  if (Exact) {
+    unsigned FirstOne = LHS.countMaxTrailingZeros();
+    if (FirstOne < MinShiftAmount) {
+      // Always poison. Return zero because we don't like returning conflict.
+      Known.setAllZero();
+      return Known;
+    }
+    MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
+  }
+
   unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
   unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
   Known.Zero.setAllBits();
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 658f3796721c4e..1816fa8bc49e8a 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -320,6 +320,20 @@ TEST(KnownBitsTest, AbsDiffSpecialCase) {
   EXPECT_EQ(0b0000ul, Res.Zero.getZExtValue());
 }
 
+TEST(KnownBitsTest, ShrExactSpecialCase) {
+    const unsigned N = 4;
+    KnownBits LHS(N), RHS(N);
+
+    LHS.One.setBit(1);
+    LHS.One.setBit(2);
+
+    EXPECT_FALSE(KnownBits::lshr(LHS, RHS).One[1]);
+    EXPECT_FALSE(KnownBits::ashr(LHS, RHS).One[1]);
+
+    EXPECT_FALSE(KnownBits::lshr(LHS, RHS).One[1]);
+    EXPECT_FALSE(KnownBits::ashr(LHS, RHS).One[1]);        
+}
+
 TEST(KnownBitsTest, BinaryExhaustive) {
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
@@ -505,7 +519,6 @@ TEST(KnownBitsTest, BinaryExhaustive) {
         return Res;
       },
       checkOptimalityBinary, /* RefinePoisonToZero */ true);
-
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::lshr(Known1, Known2);
@@ -516,6 +529,19 @@ TEST(KnownBitsTest, BinaryExhaustive) {
         return N1.lshr(N2);
       },
       checkOptimalityBinary, /* RefinePoisonToZero */ true);
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::lshr(Known1, Known2, /*ShAmtNonZero=*/false,
+                               /*Exact=*/true);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        if (N2.uge(N2.getBitWidth()))
+          return std::nullopt;
+        if (!N2.isZero() && !N1.extractBits(N2.getZExtValue(), 0).isZero())
+          return std::nullopt;
+        return N1.lshr(N2);
+      },
+      checkOptimalityBinary, /* RefinePoisonToZero */ true);
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::ashr(Known1, Known2);
@@ -526,6 +552,19 @@ TEST(KnownBitsTest, BinaryExhaustive) {
         return N1.ashr(N2);
       },
       checkOptimalityBinary, /* RefinePoisonToZero */ true);
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::ashr(Known1, Known2, /*ShAmtNonZero=*/false,
+                               /*Exact=*/true);
+      },
+      [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+        if (N2.uge(N2.getBitWidth()))
+          return std::nullopt;
+        if (!N2.isZero() && !N1.extractBits(N2.getZExtValue(), 0).isZero())
+          return std::nullopt;
+        return N1.ashr(N2);
+      },
+      checkOptimalityBinary, /* RefinePoisonToZero */ true);
 
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {

@goldsteinn goldsteinn changed the title knownbits lshr ashr [KnownBits] Implement knownbits lshr/ashr with exact flag Mar 6, 2024
Copy link

github-actions bot commented Mar 6, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown);
Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown,
/*ShAmtNonZero=*/false,
/*Exact*/ true);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you read the Exact flag from the llvm::MachineInstr? Same for NSW and NUW above. Maybe I am missing something.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am extremely sceptical that this change will make any improvement to the KnownBits analysis for G_SBFX. I'd prefer to just remove this part from the patch.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you read the Exact flag from the llvm::MachineInstr?

No, this is not related to whether the G_SBFX itself is exact. This is only exact because we're operating on the result of extractBits, which will leave zero bits outside of the masked value.

dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Mar 7, 2024
Copy link
Contributor

@jayfoad jayfoad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems OK overall, but it's a bit disappointing that it doesn't affect any codegen tests.

[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
if (N2.uge(N2.getBitWidth()))
return std::nullopt;
if (!N2.isZero() && !N1.extractBits(N2.getZExtValue(), 0).isZero())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should still work even without the !N2.isZero() check.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extractBits has:

  assert(bitPosition < BitWidth && (numBits + bitPosition) <= BitWidth &&
         "Illegal bit extraction");

so need N2 > 0.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should relax that assertion, I think. It predates zero-length APInts being made legal.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kk

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@goldsteinn huh? How does that assertion prevent calling it with numBits == 0?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Though the assertion probably should be tweaked to bitPosition **<=** BitWidth && (numBits + bitPosition) <= BitWidth, to allow for extracting the zero high-order bits.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah... I misread BitWidth as numBits...., ill update.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although APInt is somewhat littered with things like `assert(BitWidth && "zero width values not allowed");

The usage we have here should be fine, though.

Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown);
Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown,
/*ShAmtNonZero=*/false,
/*Exact*/ true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am extremely sceptical that this change will make any improvement to the KnownBits analysis for G_SBFX. I'd prefer to just remove this part from the patch.

@@ -320,6 +320,22 @@ TEST(KnownBitsTest, AbsDiffSpecialCase) {
EXPECT_EQ(0b0000ul, Res.Zero.getZExtValue());
}

TEST(KnownBitsTest, ShrExactSpecialCase) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand what this special case is - why isn't it covered by the exhaustive testing below?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was just to show a case where exact provides info that non-exact doesn't. But ill change this to an IR test where the codegen actually varies @dtcxzyw's request.

if (Exact) {
unsigned FirstOne = LHS.countMaxTrailingZeros();
if (FirstOne < MinShiftAmount) {
// Always poison. Return zero because we don't like returning conflict.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not your fault, but I really don't like that we work so hard everywhere to avoid returning conflict. We should embrace conflict!

Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown);
Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown,
/*ShAmtNonZero=*/false,
/*Exact*/ true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you read the Exact flag from the llvm::MachineInstr?

No, this is not related to whether the G_SBFX itself is exact. This is only exact because we're operating on the result of extractBits, which will leave zero bits outside of the masked value.

@dtcxzyw
Copy link
Member

dtcxzyw commented Mar 7, 2024

Can this patch improve the optimization? Could you please add some tests for demonstration?

@goldsteinn
Copy link
Contributor Author

Can this patch improve the optimization? Could you please add some tests for demonstration?

Have a unittest that shows when exact can help, but ill change to a codegen test.

The exact flag basically allows us to set an upper bound on shift
amount when we have a known 1 in `LHS`.

Typically we deduce exact using knownbits (on non-exact incoming
shifts), so this is particularly impactful, but may be useful in some
circumstances.
@goldsteinn
Copy link
Contributor Author

Can this patch improve the optimization? Could you please add some tests for demonstration?

Done.

@goldsteinn
Copy link
Contributor Author

ping

Copy link
Contributor

@jayfoad jayfoad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants