diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 3cab06079a87da..c40e5c36cdc7c8 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -2127,12 +2127,21 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, Instruction::Xor, Q, MaxRecurse)) return V; - // If the operation is with the result of a select instruction, check whether - // operating on either branch of the select always yields the same value. - if (isa(Op0) || isa(Op1)) + if (isa(Op0) || isa(Op1)) { + if (Op0->getType()->isIntOrIntVectorTy(1)) { + // A & (A && B) -> A && B + if (match(Op1, m_Select(m_Specific(Op0), m_Value(), m_Zero()))) + return Op1; + else if (match(Op0, m_Select(m_Specific(Op1), m_Value(), m_Zero()))) + return Op0; + } + // If the operation is with the result of a select instruction, check + // whether operating on either branch of the select always yields the same + // value. if (Value *V = ThreadBinOpOverSelect(Instruction::And, Op0, Op1, Q, MaxRecurse)) return V; + } // If the operation is with the result of a phi instruction, check whether // operating on all incoming values of the phi always yields the same value. @@ -2303,12 +2312,21 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, Instruction::And, Q, MaxRecurse)) return V; - // If the operation is with the result of a select instruction, check whether - // operating on either branch of the select always yields the same value. - if (isa(Op0) || isa(Op1)) + if (isa(Op0) || isa(Op1)) { + if (Op0->getType()->isIntOrIntVectorTy(1)) { + // A | (A || B) -> A || B + if (match(Op1, m_Select(m_Specific(Op0), m_One(), m_Value()))) + return Op1; + else if (match(Op0, m_Select(m_Specific(Op1), m_One(), m_Value()))) + return Op0; + } + // If the operation is with the result of a select instruction, check + // whether operating on either branch of the select always yields the same + // value. if (Value *V = ThreadBinOpOverSelect(Instruction::Or, Op0, Op1, Q, MaxRecurse)) return V; + } // (A & C1)|(B & C2) const APInt *C1, *C2; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 352126fa07ca24..59291617d2b05f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -3444,19 +3444,32 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { } } - // Pull 'not' into operands of select if both operands are one-use compares. + // Pull 'not' into operands of select if both operands are one-use compares + // or one is one-use compare and the other one is a constant. // Inverting the predicates eliminates the 'not' operation. // Example: - // not (select ?, (cmp TPred, ?, ?), (cmp FPred, ?, ?) --> + // not (select ?, (cmp TPred, ?, ?), (cmp FPred, ?, ?) --> // select ?, (cmp InvTPred, ?, ?), (cmp InvFPred, ?, ?) - // TODO: Canonicalize by hoisting 'not' into an arm of the select if only - // 1 select operand is a cmp? + // not (select ?, (cmp TPred, ?, ?), true --> + // select ?, (cmp InvTPred, ?, ?), false if (auto *Sel = dyn_cast(Op0)) { - auto *CmpT = dyn_cast(Sel->getTrueValue()); - auto *CmpF = dyn_cast(Sel->getFalseValue()); - if (CmpT && CmpF && CmpT->hasOneUse() && CmpF->hasOneUse()) { - CmpT->setPredicate(CmpT->getInversePredicate()); - CmpF->setPredicate(CmpF->getInversePredicate()); + Value *TV = Sel->getTrueValue(); + Value *FV = Sel->getFalseValue(); + auto *CmpT = dyn_cast(TV); + auto *CmpF = dyn_cast(FV); + bool InvertibleT = (CmpT && CmpT->hasOneUse()) || isa(TV); + bool InvertibleF = (CmpF && CmpF->hasOneUse()) || isa(FV); + if (InvertibleT && InvertibleF) { + Constant *One = cast(Op1); + + if (CmpT) + CmpT->setPredicate(CmpT->getInversePredicate()); + else + Sel->setTrueValue(ConstantExpr::getNot(cast(TV))); + if (CmpF) + CmpF->setPredicate(CmpF->getInversePredicate()); + else + Sel->setFalseValue(ConstantExpr::getNot(cast(FV))); return replaceInstUsesWith(I, Sel); } } diff --git a/llvm/test/Transforms/InstCombine/select-safe-transforms.ll b/llvm/test/Transforms/InstCombine/select-safe-transforms.ll index 48235863d9ff7e..35f100302d47b2 100644 --- a/llvm/test/Transforms/InstCombine/select-safe-transforms.ll +++ b/llvm/test/Transforms/InstCombine/select-safe-transforms.ll @@ -56,8 +56,7 @@ define i1 @cond_eq_or_const(i8 %X, i8 %Y) { define i1 @merge_and(i1 %X, i1 %Y) { ; CHECK-LABEL: @merge_and( ; CHECK-NEXT: [[C:%.*]] = select i1 [[X:%.*]], i1 [[Y:%.*]], i1 false -; CHECK-NEXT: [[RES:%.*]] = and i1 [[C]], [[X]] -; CHECK-NEXT: ret i1 [[RES]] +; CHECK-NEXT: ret i1 [[C]] ; %c = select i1 %X, i1 %Y, i1 false %res = and i1 %X, %c @@ -67,8 +66,7 @@ define i1 @merge_and(i1 %X, i1 %Y) { define i1 @merge_or(i1 %X, i1 %Y) { ; CHECK-LABEL: @merge_or( ; CHECK-NEXT: [[C:%.*]] = select i1 [[X:%.*]], i1 true, i1 [[Y:%.*]] -; CHECK-NEXT: [[RES:%.*]] = or i1 [[C]], [[X]] -; CHECK-NEXT: ret i1 [[RES]] +; CHECK-NEXT: ret i1 [[C]] ; %c = select i1 %X, i1 true, i1 %Y %res = or i1 %X, %c @@ -77,10 +75,10 @@ define i1 @merge_or(i1 %X, i1 %Y) { define i1 @xor_and(i1 %c, i32 %X, i32 %Y) { ; CHECK-LABEL: @xor_and( -; CHECK-NEXT: [[COMP:%.*]] = icmp ult i32 [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[C:%.*]], i1 [[COMP]], i1 false -; CHECK-NEXT: [[RES:%.*]] = xor i1 [[SEL]], true -; CHECK-NEXT: ret i1 [[RES]] +; CHECK-NEXT: [[COMP:%.*]] = icmp uge i32 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: [[NOT_C:%.*]] = xor i1 [[C:%.*]], true +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[NOT_C]], i1 true, i1 [[COMP]] +; CHECK-NEXT: ret i1 [[SEL]] ; %comp = icmp ult i32 %X, %Y %sel = select i1 %c, i1 %comp, i1 false @@ -90,10 +88,9 @@ define i1 @xor_and(i1 %c, i32 %X, i32 %Y) { define <2 x i1> @xor_and2(<2 x i1> %c, <2 x i32> %X, <2 x i32> %Y) { ; CHECK-LABEL: @xor_and2( -; CHECK-NEXT: [[COMP:%.*]] = icmp ult <2 x i32> [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[SEL:%.*]] = select <2 x i1> [[C:%.*]], <2 x i1> [[COMP]], <2 x i1> -; CHECK-NEXT: [[RES:%.*]] = xor <2 x i1> [[SEL]], -; CHECK-NEXT: ret <2 x i1> [[RES]] +; CHECK-NEXT: [[COMP:%.*]] = icmp uge <2 x i32> [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: [[SEL:%.*]] = select <2 x i1> [[C:%.*]], <2 x i1> [[COMP]], <2 x i1> +; CHECK-NEXT: ret <2 x i1> [[SEL]] ; %comp = icmp ult <2 x i32> %X, %Y %sel = select <2 x i1> %c, <2 x i1> %comp, <2 x i1> @@ -105,10 +102,9 @@ define <2 x i1> @xor_and2(<2 x i1> %c, <2 x i32> %X, <2 x i32> %Y) { define <2 x i1> @xor_and3(<2 x i1> %c, <2 x i32> %X, <2 x i32> %Y) { ; CHECK-LABEL: @xor_and3( -; CHECK-NEXT: [[COMP:%.*]] = icmp ult <2 x i32> [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[SEL:%.*]] = select <2 x i1> [[C:%.*]], <2 x i1> [[COMP]], <2 x i1> -; CHECK-NEXT: [[RES:%.*]] = xor <2 x i1> [[SEL]], -; CHECK-NEXT: ret <2 x i1> [[RES]] +; CHECK-NEXT: [[COMP:%.*]] = icmp uge <2 x i32> [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: [[SEL:%.*]] = select <2 x i1> [[C:%.*]], <2 x i1> [[COMP]], <2 x i1> +; CHECK-NEXT: ret <2 x i1> [[SEL]] ; %comp = icmp ult <2 x i32> %X, %Y %sel = select <2 x i1> %c, <2 x i1> %comp, <2 x i1> @@ -118,10 +114,10 @@ define <2 x i1> @xor_and3(<2 x i1> %c, <2 x i32> %X, <2 x i32> %Y) { define i1 @xor_or(i1 %c, i32 %X, i32 %Y) { ; CHECK-LABEL: @xor_or( -; CHECK-NEXT: [[COMP:%.*]] = icmp ult i32 [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[C:%.*]], i1 true, i1 [[COMP]] -; CHECK-NEXT: [[RES:%.*]] = xor i1 [[SEL]], true -; CHECK-NEXT: ret i1 [[RES]] +; CHECK-NEXT: [[COMP:%.*]] = icmp uge i32 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: [[NOT_C:%.*]] = xor i1 [[C:%.*]], true +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[NOT_C]], i1 [[COMP]], i1 false +; CHECK-NEXT: ret i1 [[SEL]] ; %comp = icmp ult i32 %X, %Y %sel = select i1 %c, i1 true, i1 %comp @@ -131,10 +127,9 @@ define i1 @xor_or(i1 %c, i32 %X, i32 %Y) { define <2 x i1> @xor_or2(<2 x i1> %c, <2 x i32> %X, <2 x i32> %Y) { ; CHECK-LABEL: @xor_or2( -; CHECK-NEXT: [[COMP:%.*]] = icmp ult <2 x i32> [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[SEL:%.*]] = select <2 x i1> [[C:%.*]], <2 x i1> , <2 x i1> [[COMP]] -; CHECK-NEXT: [[RES:%.*]] = xor <2 x i1> [[SEL]], -; CHECK-NEXT: ret <2 x i1> [[RES]] +; CHECK-NEXT: [[COMP:%.*]] = icmp uge <2 x i32> [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: [[SEL:%.*]] = select <2 x i1> [[C:%.*]], <2 x i1> , <2 x i1> [[COMP]] +; CHECK-NEXT: ret <2 x i1> [[SEL]] ; %comp = icmp ult <2 x i32> %X, %Y %sel = select <2 x i1> %c, <2 x i1> , <2 x i1> %comp @@ -144,10 +139,9 @@ define <2 x i1> @xor_or2(<2 x i1> %c, <2 x i32> %X, <2 x i32> %Y) { define <2 x i1> @xor_or3(<2 x i1> %c, <2 x i32> %X, <2 x i32> %Y) { ; CHECK-LABEL: @xor_or3( -; CHECK-NEXT: [[COMP:%.*]] = icmp ult <2 x i32> [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[SEL:%.*]] = select <2 x i1> [[C:%.*]], <2 x i1> , <2 x i1> [[COMP]] -; CHECK-NEXT: [[RES:%.*]] = xor <2 x i1> [[SEL]], -; CHECK-NEXT: ret <2 x i1> [[RES]] +; CHECK-NEXT: [[COMP:%.*]] = icmp uge <2 x i32> [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: [[SEL:%.*]] = select <2 x i1> [[C:%.*]], <2 x i1> , <2 x i1> [[COMP]] +; CHECK-NEXT: ret <2 x i1> [[SEL]] ; %comp = icmp ult <2 x i32> %X, %Y %sel = select <2 x i1> %c, <2 x i1> , <2 x i1> %comp