-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[KnownBits] Implement knownbits lshr/ashr with exact flag #84254
Conversation
@llvm/pr-subscribers-llvm-selectiondag @llvm/pr-subscribers-llvm-globalisel Author: None (goldsteinn) Changes
Full diff: https://github.com/llvm/llvm-project/pull/84254.diff 6 Files Affected:
diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index 46dbf0c2baa5fe..06d2c90f7b0f6b 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -402,12 +402,12 @@ struct KnownBits {
/// Compute known bits for lshr(LHS, RHS).
/// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
static KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS,
- bool ShAmtNonZero = false);
+ bool ShAmtNonZero = false, bool Exact = false);
/// Compute known bits for ashr(LHS, RHS).
/// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS,
- bool ShAmtNonZero = false);
+ bool ShAmtNonZero = false, bool Exact = false);
/// Determine if these known bits always give the same ICMP_EQ result.
static std::optional<bool> eq(const KnownBits &LHS, const KnownBits &RHS);
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 52ae9f034e5d34..3304db68e3deae 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -1142,9 +1142,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Instruction::LShr: {
- auto KF = [](const KnownBits &KnownVal, const KnownBits &KnownAmt,
- bool ShAmtNonZero) {
- return KnownBits::lshr(KnownVal, KnownAmt, ShAmtNonZero);
+ bool Exact = Q.IIQ.isExact(cast<BinaryOperator>(I));
+ auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
+ bool ShAmtNonZero) {
+ return KnownBits::lshr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
};
computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
KF);
@@ -1155,9 +1156,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Instruction::AShr: {
- auto KF = [](const KnownBits &KnownVal, const KnownBits &KnownAmt,
- bool ShAmtNonZero) {
- return KnownBits::ashr(KnownVal, KnownAmt, ShAmtNonZero);
+ bool Exact = Q.IIQ.isExact(cast<BinaryOperator>(I));
+ auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
+ bool ShAmtNonZero) {
+ return KnownBits::ashr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
};
computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
KF);
diff --git a/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp b/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
index 099bf45b2734cb..21e0b7b2b68fc7 100644
--- a/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
@@ -565,7 +565,9 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
KnownBits ExtKnown = KnownBits::makeConstant(APInt(BitWidth, BitWidth));
KnownBits ShiftKnown = KnownBits::computeForAddSub(
/*Add=*/false, /*NSW=*/false, /* NUW=*/false, ExtKnown, WidthKnown);
- Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown);
+ Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown,
+ /*ShAmtNonZero=*/false,
+ /*Exact*/ true);
break;
}
case TargetOpcode::G_UADDO:
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index f7ace79e8c51d4..92e20cc1304b70 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3485,7 +3485,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
case ISD::SRL:
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
- Known = KnownBits::lshr(Known, Known2);
+ Known = KnownBits::lshr(Known, Known2, /*ShAmtNonZero=*/false,
+ Op->getFlags().hasExact());
// Minimum shift high bits are known zero.
if (const APInt *ShMinAmt =
@@ -3495,7 +3496,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
case ISD::SRA:
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
- Known = KnownBits::ashr(Known, Known2);
+ Known = KnownBits::ashr(Known, Known2, /*ShAmtNonZero=*/false,
+ Op->getFlags().hasExact());
break;
case ISD::FSHL:
case ISD::FSHR:
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 74d857457aec1e..c33c3680825a10 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -343,7 +343,7 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
}
KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
- bool ShAmtNonZero) {
+ bool ShAmtNonZero, bool Exact) {
unsigned BitWidth = LHS.getBitWidth();
auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
KnownBits Known = LHS;
@@ -367,6 +367,18 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
// Find the common bits from all possible shifts.
APInt MaxValue = RHS.getMaxValue();
unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+
+ // If exact, bound MaxShiftAmount to first known 1 in LHS.
+ if (Exact) {
+ unsigned FirstOne = LHS.countMaxTrailingZeros();
+ if (FirstOne < MinShiftAmount) {
+ // Always poison. Return zero because we don't like returning conflict.
+ Known.setAllZero();
+ return Known;
+ }
+ MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
+ }
+
unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
Known.Zero.setAllBits();
@@ -389,7 +401,7 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
}
KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
- bool ShAmtNonZero) {
+ bool ShAmtNonZero, bool Exact) {
unsigned BitWidth = LHS.getBitWidth();
auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
KnownBits Known = LHS;
@@ -415,6 +427,18 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
// Find the common bits from all possible shifts.
APInt MaxValue = RHS.getMaxValue();
unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+
+ // If exact, bound MaxShiftAmount to first known 1 in LHS.
+ if (Exact) {
+ unsigned FirstOne = LHS.countMaxTrailingZeros();
+ if (FirstOne < MinShiftAmount) {
+ // Always poison. Return zero because we don't like returning conflict.
+ Known.setAllZero();
+ return Known;
+ }
+ MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
+ }
+
unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
Known.Zero.setAllBits();
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 658f3796721c4e..1816fa8bc49e8a 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -320,6 +320,20 @@ TEST(KnownBitsTest, AbsDiffSpecialCase) {
EXPECT_EQ(0b0000ul, Res.Zero.getZExtValue());
}
+TEST(KnownBitsTest, ShrExactSpecialCase) {
+ const unsigned N = 4;
+ KnownBits LHS(N), RHS(N);
+
+ LHS.One.setBit(1);
+ LHS.One.setBit(2);
+
+ EXPECT_FALSE(KnownBits::lshr(LHS, RHS).One[1]);
+ EXPECT_FALSE(KnownBits::ashr(LHS, RHS).One[1]);
+
+ EXPECT_FALSE(KnownBits::lshr(LHS, RHS).One[1]);
+ EXPECT_FALSE(KnownBits::ashr(LHS, RHS).One[1]);
+}
+
TEST(KnownBitsTest, BinaryExhaustive) {
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
@@ -505,7 +519,6 @@ TEST(KnownBitsTest, BinaryExhaustive) {
return Res;
},
checkOptimalityBinary, /* RefinePoisonToZero */ true);
-
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
return KnownBits::lshr(Known1, Known2);
@@ -516,6 +529,19 @@ TEST(KnownBitsTest, BinaryExhaustive) {
return N1.lshr(N2);
},
checkOptimalityBinary, /* RefinePoisonToZero */ true);
+ testBinaryOpExhaustive(
+ [](const KnownBits &Known1, const KnownBits &Known2) {
+ return KnownBits::lshr(Known1, Known2, /*ShAmtNonZero=*/false,
+ /*Exact=*/true);
+ },
+ [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+ if (N2.uge(N2.getBitWidth()))
+ return std::nullopt;
+ if (!N2.isZero() && !N1.extractBits(N2.getZExtValue(), 0).isZero())
+ return std::nullopt;
+ return N1.lshr(N2);
+ },
+ checkOptimalityBinary, /* RefinePoisonToZero */ true);
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
return KnownBits::ashr(Known1, Known2);
@@ -526,6 +552,19 @@ TEST(KnownBitsTest, BinaryExhaustive) {
return N1.ashr(N2);
},
checkOptimalityBinary, /* RefinePoisonToZero */ true);
+ testBinaryOpExhaustive(
+ [](const KnownBits &Known1, const KnownBits &Known2) {
+ return KnownBits::ashr(Known1, Known2, /*ShAmtNonZero=*/false,
+ /*Exact=*/true);
+ },
+ [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+ if (N2.uge(N2.getBitWidth()))
+ return std::nullopt;
+ if (!N2.isZero() && !N1.extractBits(N2.getZExtValue(), 0).isZero())
+ return std::nullopt;
+ return N1.ashr(N2);
+ },
+ checkOptimalityBinary, /* RefinePoisonToZero */ true);
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
40ae0a4
to
b1f8743
Compare
Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown); | ||
Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown, | ||
/*ShAmtNonZero=*/false, | ||
/*Exact*/ true); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you read the Exact
flag from the llvm::MachineInstr? Same for NSW and NUW above. Maybe I am missing something.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am extremely sceptical that this change will make any improvement to the KnownBits analysis for G_SBFX. I'd prefer to just remove this part from the patch.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you read the Exact flag from the llvm::MachineInstr?
No, this is not related to whether the G_SBFX itself is exact. This is only exact because we're operating on the result of extractBits, which will leave zero bits outside of the masked value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems OK overall, but it's a bit disappointing that it doesn't affect any codegen tests.
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> { | ||
if (N2.uge(N2.getBitWidth())) | ||
return std::nullopt; | ||
if (!N2.isZero() && !N1.extractBits(N2.getZExtValue(), 0).isZero()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should still work even without the !N2.isZero()
check.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
extractBits
has:
assert(bitPosition < BitWidth && (numBits + bitPosition) <= BitWidth &&
"Illegal bit extraction");
so need N2 > 0
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should relax that assertion, I think. It predates zero-length APInts being made legal.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
kk
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@goldsteinn huh? How does that assertion prevent calling it with numBits == 0
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Though the assertion probably should be tweaked to bitPosition **<=** BitWidth && (numBits + bitPosition) <= BitWidth
, to allow for extracting the zero high-order bits.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh yeah... I misread BitWidth
as numBits
...., ill update.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Although APInt
is somewhat littered with things like `assert(BitWidth && "zero width values not allowed");
The usage we have here should be fine, though.
Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown); | ||
Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown, | ||
/*ShAmtNonZero=*/false, | ||
/*Exact*/ true); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am extremely sceptical that this change will make any improvement to the KnownBits analysis for G_SBFX. I'd prefer to just remove this part from the patch.
@@ -320,6 +320,22 @@ TEST(KnownBitsTest, AbsDiffSpecialCase) { | |||
EXPECT_EQ(0b0000ul, Res.Zero.getZExtValue()); | |||
} | |||
|
|||
TEST(KnownBitsTest, ShrExactSpecialCase) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand what this special case is - why isn't it covered by the exhaustive testing below?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was just to show a case where exact provides info that non-exact doesn't. But ill change this to an IR test where the codegen actually varies @dtcxzyw's request.
if (Exact) { | ||
unsigned FirstOne = LHS.countMaxTrailingZeros(); | ||
if (FirstOne < MinShiftAmount) { | ||
// Always poison. Return zero because we don't like returning conflict. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not your fault, but I really don't like that we work so hard everywhere to avoid returning conflict. We should embrace conflict!
Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown); | ||
Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown, | ||
/*ShAmtNonZero=*/false, | ||
/*Exact*/ true); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you read the Exact flag from the llvm::MachineInstr?
No, this is not related to whether the G_SBFX itself is exact. This is only exact because we're operating on the result of extractBits, which will leave zero bits outside of the masked value.
Can this patch improve the optimization? Could you please add some tests for demonstration? |
Have a unittest that shows when |
… with `exact` flag; NFC
The exact flag basically allows us to set an upper bound on shift amount when we have a known 1 in `LHS`. Typically we deduce exact using knownbits (on non-exact incoming shifts), so this is particularly impactful, but may be useful in some circumstances.
b1f8743
to
895ce67
Compare
Done. |
ping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
lshr
/ashr
withexact
flag; NFCexact
inlshr
/ashr
; NFClshr
/ashr
with exact flag