diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 6745b1b394cfe6..aa60b0fa58cc93 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -2706,43 +2706,45 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { // (A & MaskC0) | (B & MaskC1) const APInt *MaskC0, *MaskC1; - if (match(C, m_APInt(MaskC0)) && match(D, m_APInt(MaskC1)) && - *MaskC0 == ~*MaskC1) { + if (match(C, m_APInt(MaskC0)) && match(D, m_APInt(MaskC1))) { Value *X; + if (*MaskC0 == ~*MaskC1) { + // ((X | B) & MaskC) | (B & ~MaskC) -> (X & MaskC) | B + if (match(A, m_c_Or(m_Value(X), m_Specific(B)))) + return BinaryOperator::CreateOr(Builder.CreateAnd(X, *MaskC0), B); + // (A & MaskC) | ((X | A) & ~MaskC) -> (X & ~MaskC) | A + if (match(B, m_c_Or(m_Specific(A), m_Value(X)))) + return BinaryOperator::CreateOr(Builder.CreateAnd(X, *MaskC1), A); + + // ((X ^ B) & MaskC) | (B & ~MaskC) -> (X & MaskC) ^ B + if (match(A, m_c_Xor(m_Value(X), m_Specific(B)))) + return BinaryOperator::CreateXor(Builder.CreateAnd(X, *MaskC0), B); + // (A & MaskC) | ((X ^ A) & ~MaskC) -> (X & ~MaskC) ^ A + if (match(B, m_c_Xor(m_Specific(A), m_Value(X)))) + return BinaryOperator::CreateXor(Builder.CreateAnd(X, *MaskC1), A); + } - // ((X | B) & MaskC) | (B & ~MaskC) -> (X & MaskC) | B - if (match(A, m_c_Or(m_Value(X), m_Specific(B)))) - return BinaryOperator::CreateOr(Builder.CreateAnd(X, *MaskC0), B); - // (A & MaskC) | ((X | A) & ~MaskC) -> (X & ~MaskC) | A - if (match(B, m_c_Or(m_Specific(A), m_Value(X)))) - return BinaryOperator::CreateOr(Builder.CreateAnd(X, *MaskC1), A); + if ((*MaskC0 & *MaskC1).isZero()) { + // ((X | B) & C1) | (B & C2) --> (X | B) & (C1 | C2) + // iff (C1 & C2) == 0 and (X & ~C1) == 0 + if (match(A, m_c_Or(m_Value(X), m_Specific(B))) && + MaskedValueIsZero(X, ~*MaskC0, 0, &I)) + return BinaryOperator::CreateAnd( + A, ConstantInt::get(I.getType(), *MaskC0 | *MaskC1)); - // ((X ^ B) & MaskC) | (B & ~MaskC) -> (X & MaskC) ^ B - if (match(A, m_c_Xor(m_Value(X), m_Specific(B)))) - return BinaryOperator::CreateXor(Builder.CreateAnd(X, *MaskC0), B); - // (A & MaskC) | ((X ^ A) & ~MaskC) -> (X & ~MaskC) ^ A - if (match(B, m_c_Xor(m_Specific(A), m_Value(X)))) - return BinaryOperator::CreateXor(Builder.CreateAnd(X, *MaskC1), A); + // (A & C1) | ((X | A) & C2) --> (X | A) & (C1 | C2) + // iff (C1 & C2) == 0 and (X & ~C1) == 0 + if (match(B, m_c_Or(m_Value(X), m_Specific(A))) && + MaskedValueIsZero(X, ~*MaskC1, 0, &I)) + return BinaryOperator::CreateAnd( + B, ConstantInt::get(I.getType(), *MaskC0 | *MaskC1)); + } } // (A & C1)|(B & C2) ConstantInt *C1, *C2; if (match(C, m_ConstantInt(C1)) && match(D, m_ConstantInt(C2))) { - Value *N; if ((C1->getValue() & C2->getValue()).isZero()) { - // ((B | N) & C1) | (B & C2) --> (B | N) & (C1 | C2) - // iff (C1 & C2) == 0 and (N & ~C1) == 0 - if (match(A, m_c_Or(m_Specific(B), m_Value(N))) && - MaskedValueIsZero(N, ~C1->getValue(), 0, &I)) - return BinaryOperator::CreateAnd( - A, Builder.getInt(C1->getValue() | C2->getValue())); - // (A & C1) | ((A | N) & C2) --> (A | N) & (C1 | C2) - // iff (C1 & C2) == 0 and (N & ~C1) == 0 - if (match(B, m_c_Or(m_Specific(A), m_Value(N))) && - MaskedValueIsZero(N, ~C2->getValue(), 0, &I)) - return BinaryOperator::CreateAnd( - B, Builder.getInt(C1->getValue() | C2->getValue())); - // ((V|C3)&C1) | ((V|C4)&C2) --> (V|C3|C4)&(C1|C2) // iff (C1&C2) == 0 and (C3&~C1) == 0 and (C4&~C2) == 0. Value *V1 = nullptr, *V2 = nullptr; diff --git a/llvm/test/Transforms/InstCombine/and-or.ll b/llvm/test/Transforms/InstCombine/and-or.ll index 7996fd3c63f616..1e13be01491949 100644 --- a/llvm/test/Transforms/InstCombine/and-or.ll +++ b/llvm/test/Transforms/InstCombine/and-or.ll @@ -124,7 +124,7 @@ define <2 x i8> @or_and_or_commute1_splat(<2 x i8> %x) { ; CHECK-NEXT: call void @use_vec(<2 x i8> [[X1]]) ; CHECK-NEXT: [[X2:%.*]] = and <2 x i8> [[X]], ; CHECK-NEXT: call void @use_vec(<2 x i8> [[X2]]) -; CHECK-NEXT: [[R:%.*]] = or <2 x i8> [[X2]], [[X1]] +; CHECK-NEXT: [[R:%.*]] = and <2 x i8> [[XN]], ; CHECK-NEXT: ret <2 x i8> [[R]] ; %xn = or <2 x i8> %x, @@ -169,7 +169,7 @@ define <2 x i8> @or_and_or_commute2_splat(<2 x i8> %x, <2 x i8> %y) { ; CHECK-NEXT: call void @use_vec(<2 x i8> [[X1]]) ; CHECK-NEXT: [[X2:%.*]] = and <2 x i8> [[X]], ; CHECK-NEXT: call void @use_vec(<2 x i8> [[X2]]) -; CHECK-NEXT: [[R:%.*]] = or <2 x i8> [[X1]], [[X2]] +; CHECK-NEXT: [[R:%.*]] = and <2 x i8> [[XN]], ; CHECK-NEXT: ret <2 x i8> [[R]] ; %n = lshr <2 x i8> %y,