Skip to content

Commit

Permalink
[KnownBits] Make abdu and abds optimal (#89081)
Browse files Browse the repository at this point in the history
Fixes #84212
  • Loading branch information
jayfoad committed Apr 18, 2024
1 parent 8a21d59 commit d8a26ca
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 84 deletions.
2 changes: 1 addition & 1 deletion llvm/include/llvm/Support/KnownBits.h
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ struct KnownBits {
static KnownBits abdu(const KnownBits &LHS, const KnownBits &RHS);

/// Compute known bits for abds(LHS, RHS).
static KnownBits abds(const KnownBits &LHS, const KnownBits &RHS);
static KnownBits abds(KnownBits LHS, KnownBits RHS);

/// Compute known bits for shl(LHS, RHS).
/// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
Expand Down
66 changes: 39 additions & 27 deletions llvm/lib/Support/KnownBits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,41 +232,53 @@ KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
}

// Computes known bits for abdu(LHS, RHS), the unsigned absolute difference.
// NOTE(review): this span is a rendered diff hunk without +/- markers, so
// statements of the earlier implementation (the umax/umin-based MinMaxDiff,
// the NUW=false Diff0/Diff1 calls and the unionWith/assert tail) appear
// interleaved with their replacements. Read it as a before/after listing, not
// as one compilable body -- confirm against the commit itself.
KnownBits KnownBits::abdu(const KnownBits &LHS, const KnownBits &RHS) {
// abdu(LHS,RHS) = sub(umax(LHS,RHS), umin(LHS,RHS)).
KnownBits UMaxValue = umax(LHS, RHS);
KnownBits UMinValue = umin(LHS, RHS);
KnownBits MinMaxDiff = computeForAddSub(/*Add=*/false, /*NSW=*/false,
/*NUW=*/true, UMaxValue, UMinValue);
// If we know which argument is larger, return (sub LHS, RHS) or
// (sub RHS, LHS) directly.
// LHS's smallest possible value is already >= RHS's largest possible value,
// so LHS >= RHS holds for every concrete pair of inputs.
if (LHS.getMinValue().uge(RHS.getMaxValue()))
return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS,
RHS);
// Mirror image of the check above: RHS >= LHS always holds.
if (RHS.getMinValue().uge(LHS.getMaxValue()))
return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS,
LHS);

// find the common bits between sub(LHS,RHS) and sub(RHS,LHS).
// By construction, the subtraction in abdu never has unsigned overflow.
// Find the common bits between (sub nuw LHS, RHS) and (sub nuw RHS, LHS).
KnownBits Diff0 =
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS, RHS);
KnownBits Diff1 =
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS, LHS);
KnownBits SubDiff = Diff0.intersectWith(Diff1);

KnownBits KnownAbsDiff = MinMaxDiff.unionWith(SubDiff);
assert(!KnownAbsDiff.hasConflict() && "Bad Output");
return KnownAbsDiff;
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, RHS, LHS);
// The operand order is not known on this path, so only the bits on which both
// subtraction orders agree are reported.
return Diff0.intersectWith(Diff1);
}

// Computes known bits for abds(LHS, RHS), the signed absolute difference.
// NOTE(review): this span is a rendered diff hunk without +/- markers. It
// contains two signature lines (const-reference and by-value parameters) and
// statements of the earlier smax/smin-based implementation interleaved with
// their replacements. Read it as a before/after listing, not as one
// compilable body -- confirm against the commit itself.
KnownBits KnownBits::abds(const KnownBits &LHS, const KnownBits &RHS) {
// abds(LHS,RHS) = sub(smax(LHS,RHS), smin(LHS,RHS)).
KnownBits SMaxValue = smax(LHS, RHS);
KnownBits SMinValue = smin(LHS, RHS);
KnownBits MinMaxDiff = computeForAddSub(/*Add=*/false, /*NSW=*/false,
/*NUW=*/false, SMaxValue, SMinValue);
KnownBits KnownBits::abds(KnownBits LHS, KnownBits RHS) {
// If we know which argument is larger, return (sub LHS, RHS) or
// (sub RHS, LHS) directly.
// LHS's smallest possible signed value is already >= RHS's largest possible
// signed value, so LHS >= RHS holds for every concrete pair of inputs.
if (LHS.getSignedMinValue().sge(RHS.getSignedMaxValue()))
return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS,
RHS);
// Mirror image of the check above: RHS >= LHS always holds.
if (RHS.getSignedMinValue().sge(LHS.getSignedMaxValue()))
return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS,
LHS);

// Shift both arguments from the signed range to the unsigned range, e.g. from
// [-0x80, 0x7F] to [0, 0xFF]. This allows us to use "sub nuw" below just like
// abdu does.
// Note that we can't just use "sub nsw" instead because abds has signed
// inputs but an unsigned result, which makes the overflow conditions
// different.
unsigned SignBitPosition = LHS.getBitWidth() - 1;
// With the by-value signature, Arg points at this function's own copies of
// LHS/RHS, so the edit below does not touch the caller's objects.
for (auto Arg : {&LHS, &RHS}) {
// Exchange the Zero and One entries at the sign bit: a known sign bit is
// inverted, while an unknown sign bit (both entries clear) stays unknown.
bool Tmp = Arg->Zero[SignBitPosition];
Arg->Zero.setBitVal(SignBitPosition, Arg->One[SignBitPosition]);
Arg->One.setBitVal(SignBitPosition, Tmp);
}

// find the common bits between sub(LHS,RHS) and sub(RHS,LHS).
// Find the common bits between (sub nuw LHS, RHS) and (sub nuw RHS, LHS).
KnownBits Diff0 =
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS, RHS);
KnownBits Diff1 =
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS, LHS);
KnownBits SubDiff = Diff0.intersectWith(Diff1);

KnownBits KnownAbsDiff = MinMaxDiff.unionWith(SubDiff);
assert(!KnownAbsDiff.hasConflict() && "Bad Output");
return KnownAbsDiff;
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, RHS, LHS);
// The operand order is not known on this path, so only the bits on which both
// subtraction orders agree are reported.
return Diff0.intersectWith(Diff1);
}

static unsigned getMaxShiftAmount(const APInt &MaxValue, unsigned BitWidth) {
Expand Down
58 changes: 2 additions & 56 deletions llvm/unittests/Support/KnownBitsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,58 +294,6 @@ TEST(KnownBitsTest, SignBitUnknown) {
EXPECT_TRUE(Known.isSignUnknown());
}

TEST(KnownBitsTest, ABDUSpecialCase) {
  // abdu is computed with two separate formulations, and each of them is
  // still required for some inputs. One 4-bit input pair per formulation is
  // pinned below.

  // Builds a 4-bit KnownBits from its known-one and known-zero masks.
  auto MakeKnown = [](unsigned OneMask, unsigned ZeroMask) {
    KnownBits Known;
    Known.One = APInt(4, OneMask);
    Known.Zero = APInt(4, ZeroMask);
    return Known;
  };

  // Formulation: abdu(LHS,RHS) = sub(umax(LHS,RHS), umin(LHS,RHS)).
  // Inputs 1011 and 101? -> expected 000?, which is also the exact answer.
  {
    const KnownBits Op0 = MakeKnown(0b1011, 0b0100);
    const KnownBits Op1 = MakeKnown(0b1010, 0b0100);
    const KnownBits Diff = KnownBits::abdu(Op0, Op1);
    EXPECT_EQ(0b0000ul, Diff.One.getZExtValue());
    EXPECT_EQ(0b1110ul, Diff.Zero.getZExtValue());
  }

  // Formulation: common bits of sub(LHS,RHS) and sub(RHS,LHS).
  // Inputs ???1 and 1000 -> expected ???1 (the exact answer would be 0??1).
  {
    const KnownBits Op0 = MakeKnown(0b0001, 0b0000);
    const KnownBits Op1 = MakeKnown(0b1000, 0b0111);
    const KnownBits Diff = KnownBits::abdu(Op0, Op1);
    EXPECT_EQ(0b0001ul, Diff.One.getZExtValue());
    EXPECT_EQ(0b0000ul, Diff.Zero.getZExtValue());
  }
}

TEST(KnownBitsTest, ABDSSpecialCase) {
  // abds is computed with two separate formulations, and each of them is
  // still required for some inputs. One 4-bit input pair per formulation is
  // pinned below.

  // Builds a 4-bit KnownBits from its known-one and known-zero masks.
  auto MakeKnown = [](unsigned OneMask, unsigned ZeroMask) {
    KnownBits Known;
    Known.One = APInt(4, OneMask);
    Known.Zero = APInt(4, ZeroMask);
    return Known;
  };

  // Formulation: abds(LHS,RHS) = sub(smax(LHS,RHS), smin(LHS,RHS)).
  // Inputs 1011 and 10?? -> expected 00?? (Zero mask 1100, i.e. -4 when
  // sign-extended from 4 bits).
  {
    const KnownBits Op0 = MakeKnown(0b1011, 0b0100);
    const KnownBits Op1 = MakeKnown(0b1000, 0b0100);
    const KnownBits Diff = KnownBits::abds(Op0, Op1);
    EXPECT_EQ(0, Diff.One.getSExtValue());
    EXPECT_EQ(-4, Diff.Zero.getSExtValue());
  }

  // Formulation: common bits of sub(LHS,RHS) and sub(RHS,LHS).
  // Inputs ???1 and 1000 -> expected ???1 (the exact answer would be 0??1).
  {
    const KnownBits Op0 = MakeKnown(0b0001, 0b0000);
    const KnownBits Op1 = MakeKnown(0b1000, 0b0111);
    const KnownBits Diff = KnownBits::abds(Op0, Op1);
    EXPECT_EQ(1, Diff.One.getSExtValue());
    EXPECT_EQ(0, Diff.Zero.getSExtValue());
  }
}

TEST(KnownBitsTest, BinaryExhaustive) {
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
Expand All @@ -366,10 +314,8 @@ TEST(KnownBitsTest, BinaryExhaustive) {
testBinaryOpExhaustive(KnownBits::umin, APIntOps::umin);
testBinaryOpExhaustive(KnownBits::smax, APIntOps::smax);
testBinaryOpExhaustive(KnownBits::smin, APIntOps::smin);
testBinaryOpExhaustive(KnownBits::abdu, APIntOps::abdu,
checkCorrectnessOnlyBinary);
testBinaryOpExhaustive(KnownBits::abds, APIntOps::abds,
checkCorrectnessOnlyBinary);
testBinaryOpExhaustive(KnownBits::abdu, APIntOps::abdu);
testBinaryOpExhaustive(KnownBits::abds, APIntOps::abds);
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
return KnownBits::udiv(Known1, Known2);
Expand Down

0 comments on commit d8a26ca

Please sign in to comment.