diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 51bdbb6206a8d2..6d5b220f0694a3 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -798,36 +798,30 @@ foldAndOrOfEqualityCmpsWithConstants(ICmpInst *LHS, ICmpInst *RHS, // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) Value *InstCombinerImpl::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, - BinaryOperator &Logic) { - bool JoinedByAnd = Logic.getOpcode() == Instruction::And; - assert((JoinedByAnd || Logic.getOpcode() == Instruction::Or) && - "Wrong opcode"); - ICmpInst::Predicate Pred = LHS->getPredicate(); - if (Pred != RHS->getPredicate()) - return nullptr; - if (JoinedByAnd && Pred != ICmpInst::ICMP_NE) - return nullptr; - if (!JoinedByAnd && Pred != ICmpInst::ICMP_EQ) + Instruction *CxtI, + bool IsAnd) { + CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_NE : CmpInst::ICMP_EQ; + if (LHS->getPredicate() != Pred || RHS->getPredicate() != Pred) return nullptr; if (!match(LHS->getOperand(1), m_Zero()) || !match(RHS->getOperand(1), m_Zero())) return nullptr; - Value *A, *B, *C, *D; - if (match(LHS->getOperand(0), m_And(m_Value(A), m_Value(B))) && - match(RHS->getOperand(0), m_And(m_Value(C), m_Value(D)))) { - if (A == D || B == D) - std::swap(C, D); - if (B == C) - std::swap(A, B); - - if (A == C && - isKnownToBeAPowerOfTwo(B, false, 0, &Logic) && - isKnownToBeAPowerOfTwo(D, false, 0, &Logic)) { - Value *Mask = Builder.CreateOr(B, D); - Value *Masked = Builder.CreateAnd(A, Mask); - auto NewPred = JoinedByAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE; + Value *L1, *L2, *R1, *R2; + if (match(LHS->getOperand(0), m_And(m_Value(L1), m_Value(L2))) && + match(RHS->getOperand(0), m_And(m_Value(R1), m_Value(R2)))) { + if (L1 == R2 || L2 == R2) + std::swap(R1, R2); + if (L2 == R1) + std::swap(L1, L2); + + if (L1 == R1 && + isKnownToBeAPowerOfTwo(L2, false, 0, CxtI) && + isKnownToBeAPowerOfTwo(R2, false, 0, CxtI)) { + Value *Mask = Builder.CreateOr(L2, R2); + Value *Masked = Builder.CreateAnd(L1, Mask); + auto NewPred = IsAnd ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE; return Builder.CreateICmp(NewPred, Masked, Mask); } } @@ -1210,7 +1204,8 @@ Value *InstCombinerImpl::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) // if K1 and K2 are a one-bit mask. - if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, And)) + if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, &And, + /* IsAnd */ true)) return V; ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); @@ -2367,7 +2362,8 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) // if K1 and K2 are a one-bit mask. - if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, Or)) + if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, &Or, + /* IsAnd */ false)) return V; ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 56204db681299d..a3eabfd7fcfadf 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -350,7 +350,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final Value *foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd); Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, - BinaryOperator &Logic); + Instruction *CxtI, bool IsAnd); Value *matchSelectFromAndOr(Value *A, Value *B, Value *C, Value *D); Value *getSelectCondition(Value *A, Value *B);