Skip to content

Commit

Permalink
[ValueTracking] Checking haveNoCommonBitsSet for (x & y) and ~(x | y)
Browse files Browse the repository at this point in the history
This one tries to fix:
#53357.

Simply, this one would check (x & y) and ~(x | y) in
haveNoCommonBitsSet. Since they shouldn't have common bits (we could
traverse the case by enumerating), and we could convert this one to (x &
y) | ~(x | y) . Then the compiler could handle it in
InstCombineAndOrXor.
Further more, since ((x & y) + (~x & ~y)) would be converted to ((x & y)
+ ~(x | y)), this patch would fix it too.

https://alive2.llvm.org/ce/z/qsKzRS

Reviewed By: spatel, xbolva00, RKSimon, lebedev.ri

Differential Revision: https://reviews.llvm.org/D118094
  • Loading branch information
ChuanqiXu9 committed Feb 16, 2022
1 parent d30ca5e commit a2609be
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 41 deletions.
26 changes: 19 additions & 7 deletions llvm/lib/Analysis/ValueTracking.cpp
Expand Up @@ -275,13 +275,25 @@ bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS,
assert(LHS->getType()->isIntOrIntVectorTy() &&
"LHS and RHS should be integers");
// Look for an inverted mask: (X & ~M) op (Y & M).
Value *M;
if (match(LHS, m_c_And(m_Not(m_Value(M)), m_Value())) &&
match(RHS, m_c_And(m_Specific(M), m_Value())))
return true;
if (match(RHS, m_c_And(m_Not(m_Value(M)), m_Value())) &&
match(LHS, m_c_And(m_Specific(M), m_Value())))
return true;
{
Value *M;
if (match(LHS, m_c_And(m_Not(m_Value(M)), m_Value())) &&
match(RHS, m_c_And(m_Specific(M), m_Value())))
return true;
if (match(RHS, m_c_And(m_Not(m_Value(M)), m_Value())) &&
match(LHS, m_c_And(m_Specific(M), m_Value())))
return true;
}
// Look for: (A & B) op ~(A | B)
{
Value *A, *B;
if (match(LHS, m_And(m_Value(A), m_Value(B))) &&
match(RHS, m_Not(m_c_Or(m_Specific(A), m_Specific(B)))))
return true;
if (match(RHS, m_And(m_Value(A), m_Value(B))) &&
match(LHS, m_Not(m_c_Or(m_Specific(A), m_Specific(B)))))
return true;
}
IntegerType *IT = cast<IntegerType>(LHS->getType()->getScalarType());
KnownBits LHSKnown(IT->getBitWidth());
KnownBits RHSKnown(IT->getBitWidth());
Expand Down
54 changes: 20 additions & 34 deletions llvm/test/Transforms/InstCombine/pr53357.ll
Expand Up @@ -5,11 +5,9 @@
; (x & y) + ~(x | y)
define i32 @src(i32 %0, i32 %1) {
; CHECK-LABEL: @src(
; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP1:%.*]], [[TMP0:%.*]]
; CHECK-NEXT: [[TMP4:%.*]] = or i32 [[TMP1]], [[TMP0]]
; CHECK-NEXT: [[TMP5:%.*]] = xor i32 [[TMP4]], -1
; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[TMP3]], [[TMP5]]
; CHECK-NEXT: ret i32 [[TMP6]]
; CHECK-NEXT: [[TMP3:%.*]] = xor i32 [[TMP1:%.*]], [[TMP0:%.*]]
; CHECK-NEXT: [[TMP4:%.*]] = xor i32 [[TMP3]], -1
; CHECK-NEXT: ret i32 [[TMP4]]
;
%3 = and i32 %1, %0
%4 = or i32 %1, %0
Expand All @@ -21,11 +19,9 @@ define i32 @src(i32 %0, i32 %1) {
; vector version of src
define <2 x i32> @src_vec(<2 x i32> %0, <2 x i32> %1) {
; CHECK-LABEL: @src_vec(
; CHECK-NEXT: [[TMP3:%.*]] = and <2 x i32> [[TMP1:%.*]], [[TMP0:%.*]]
; CHECK-NEXT: [[TMP4:%.*]] = or <2 x i32> [[TMP1]], [[TMP0]]
; CHECK-NEXT: [[TMP5:%.*]] = xor <2 x i32> [[TMP4]], <i32 -1, i32 -1>
; CHECK-NEXT: [[TMP6:%.*]] = add <2 x i32> [[TMP3]], [[TMP5]]
; CHECK-NEXT: ret <2 x i32> [[TMP6]]
; CHECK-NEXT: [[TMP3:%.*]] = xor <2 x i32> [[TMP1:%.*]], [[TMP0:%.*]]
; CHECK-NEXT: [[TMP4:%.*]] = xor <2 x i32> [[TMP3]], <i32 -1, i32 -1>
; CHECK-NEXT: ret <2 x i32> [[TMP4]]
;
%3 = and <2 x i32> %1, %0
%4 = or <2 x i32> %1, %0
Expand All @@ -37,11 +33,9 @@ define <2 x i32> @src_vec(<2 x i32> %0, <2 x i32> %1) {
; vector version of src with undef values
define <2 x i32> @src_vec_undef(<2 x i32> %0, <2 x i32> %1) {
; CHECK-LABEL: @src_vec_undef(
; CHECK-NEXT: [[TMP3:%.*]] = and <2 x i32> [[TMP1:%.*]], [[TMP0:%.*]]
; CHECK-NEXT: [[TMP4:%.*]] = or <2 x i32> [[TMP1]], [[TMP0]]
; CHECK-NEXT: [[TMP5:%.*]] = xor <2 x i32> [[TMP4]], <i32 -1, i32 undef>
; CHECK-NEXT: [[TMP6:%.*]] = add <2 x i32> [[TMP3]], [[TMP5]]
; CHECK-NEXT: ret <2 x i32> [[TMP6]]
; CHECK-NEXT: [[TMP3:%.*]] = xor <2 x i32> [[TMP1:%.*]], [[TMP0:%.*]]
; CHECK-NEXT: [[TMP4:%.*]] = xor <2 x i32> [[TMP3]], <i32 -1, i32 -1>
; CHECK-NEXT: ret <2 x i32> [[TMP4]]
;
%3 = and <2 x i32> %1, %0
%4 = or <2 x i32> %1, %0
Expand All @@ -53,11 +47,9 @@ define <2 x i32> @src_vec_undef(<2 x i32> %0, <2 x i32> %1) {
; (x & y) + ~(y | x)
define i32 @src2(i32 %0, i32 %1) {
; CHECK-LABEL: @src2(
; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP1:%.*]], [[TMP0:%.*]]
; CHECK-NEXT: [[TMP4:%.*]] = or i32 [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[TMP5:%.*]] = xor i32 [[TMP4]], -1
; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[TMP3]], [[TMP5]]
; CHECK-NEXT: ret i32 [[TMP6]]
; CHECK-NEXT: [[TMP3:%.*]] = xor i32 [[TMP1:%.*]], [[TMP0:%.*]]
; CHECK-NEXT: [[TMP4:%.*]] = xor i32 [[TMP3]], -1
; CHECK-NEXT: ret i32 [[TMP4]]
;
%3 = and i32 %1, %0
%4 = or i32 %0, %1
Expand All @@ -69,11 +61,9 @@ define i32 @src2(i32 %0, i32 %1) {
; (x & y) + (~x & ~y)
define i32 @src3(i32 %0, i32 %1) {
; CHECK-LABEL: @src3(
; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP1:%.*]], [[TMP0:%.*]]
; CHECK-NEXT: [[DOTDEMORGAN:%.*]] = or i32 [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[TMP4:%.*]] = xor i32 [[DOTDEMORGAN]], -1
; CHECK-NEXT: [[TMP5:%.*]] = add i32 [[TMP3]], [[TMP4]]
; CHECK-NEXT: ret i32 [[TMP5]]
; CHECK-NEXT: [[TMP3:%.*]] = xor i32 [[TMP1:%.*]], [[TMP0:%.*]]
; CHECK-NEXT: [[TMP4:%.*]] = xor i32 [[TMP3]], -1
; CHECK-NEXT: ret i32 [[TMP4]]
;
%3 = and i32 %1, %0
%4 = xor i32 %0, -1
Expand All @@ -86,11 +76,9 @@ define i32 @src3(i32 %0, i32 %1) {
; ~(x | y) + (y & x)
define i32 @src4(i32 %0, i32 %1) {
; CHECK-LABEL: @src4(
; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP0:%.*]], [[TMP1:%.*]]
; CHECK-NEXT: [[TMP4:%.*]] = or i32 [[TMP1]], [[TMP0]]
; CHECK-NEXT: [[TMP5:%.*]] = xor i32 [[TMP4]], -1
; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[TMP3]], [[TMP5]]
; CHECK-NEXT: ret i32 [[TMP6]]
; CHECK-NEXT: [[TMP3:%.*]] = xor i32 [[TMP0:%.*]], [[TMP1:%.*]]
; CHECK-NEXT: [[TMP4:%.*]] = xor i32 [[TMP3]], -1
; CHECK-NEXT: ret i32 [[TMP4]]
;
%3 = and i32 %0, %1
%4 = or i32 %1, %0
Expand All @@ -102,11 +90,9 @@ define i32 @src4(i32 %0, i32 %1) {
; ~(x | y) + (x & y)
define i32 @src5(i32 %0, i32 %1) {
; CHECK-LABEL: @src5(
; CHECK-NEXT: [[TMP3:%.*]] = or i32 [[TMP1:%.*]], [[TMP0:%.*]]
; CHECK-NEXT: [[TMP3:%.*]] = xor i32 [[TMP1:%.*]], [[TMP0:%.*]]
; CHECK-NEXT: [[TMP4:%.*]] = xor i32 [[TMP3]], -1
; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP1]], [[TMP0]]
; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[TMP5]], [[TMP4]]
; CHECK-NEXT: ret i32 [[TMP6]]
; CHECK-NEXT: ret i32 [[TMP4]]
;
%3 = or i32 %1, %0
%4 = xor i32 %3, -1
Expand Down
56 changes: 56 additions & 0 deletions llvm/unittests/Analysis/ValueTrackingTest.cpp
Expand Up @@ -1745,6 +1745,62 @@ TEST_F(ValueTrackingTest, HaveNoCommonBitsSet) {
EXPECT_TRUE(haveNoCommonBitsSet(LHS, RHS, DL));
EXPECT_TRUE(haveNoCommonBitsSet(RHS, LHS, DL));
}
{
// Check for (A & B) and ~(A | B)
auto M = parseModule(R"(
define void @test(i32 %A, i32 %B) {
%LHS = and i32 %A, %B
%or = or i32 %A, %B
%RHS = xor i32 %or, -1
%LHS2 = and i32 %B, %A
%or2 = or i32 %A, %B
%RHS2 = xor i32 %or2, -1
ret void
})");

auto *F = M->getFunction("test");
const DataLayout &DL = M->getDataLayout();

auto *LHS = findInstructionByNameOrNull(F, "LHS");
auto *RHS = findInstructionByNameOrNull(F, "RHS");
EXPECT_TRUE(haveNoCommonBitsSet(LHS, RHS, DL));
EXPECT_TRUE(haveNoCommonBitsSet(RHS, LHS, DL));

auto *LHS2 = findInstructionByNameOrNull(F, "LHS2");
auto *RHS2 = findInstructionByNameOrNull(F, "RHS2");
EXPECT_TRUE(haveNoCommonBitsSet(LHS2, RHS2, DL));
EXPECT_TRUE(haveNoCommonBitsSet(RHS2, LHS2, DL));
}
{
// Check for (A & B) and ~(A | B) in vector version
auto M = parseModule(R"(
define void @test(<2 x i32> %A, <2 x i32> %B) {
%LHS = and <2 x i32> %A, %B
%or = or <2 x i32> %A, %B
%RHS = xor <2 x i32> %or, <i32 -1, i32 -1>
%LHS2 = and <2 x i32> %B, %A
%or2 = or <2 x i32> %A, %B
%RHS2 = xor <2 x i32> %or2, <i32 -1, i32 -1>
ret void
})");

auto *F = M->getFunction("test");
const DataLayout &DL = M->getDataLayout();

auto *LHS = findInstructionByNameOrNull(F, "LHS");
auto *RHS = findInstructionByNameOrNull(F, "RHS");
EXPECT_TRUE(haveNoCommonBitsSet(LHS, RHS, DL));
EXPECT_TRUE(haveNoCommonBitsSet(RHS, LHS, DL));

auto *LHS2 = findInstructionByNameOrNull(F, "LHS2");
auto *RHS2 = findInstructionByNameOrNull(F, "RHS2");
EXPECT_TRUE(haveNoCommonBitsSet(LHS2, RHS2, DL));
EXPECT_TRUE(haveNoCommonBitsSet(RHS2, LHS2, DL));
}
}

class IsBytewiseValueTest : public ValueTrackingTest,
Expand Down

0 comments on commit a2609be

Please sign in to comment.