diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index ee090e0125082..4e184fbb047d9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1285,22 +1285,88 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
     Swapped = true;
   }
 
-  // In X == Y ? f(X) : Z, try to evaluate f(Y) and replace the operand.
-  // Make sure Y cannot be undef though, as we might pick different values for
-  // undef in the icmp and in f(Y). Additionally, take care to avoid replacing
-  // X == Y ? X : Z with X == Y ? Y : Z, as that would lead to an infinite
-  // replacement cycle.
   Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
-  if (TrueVal != CmpLHS &&
-      isGuaranteedNotToBeUndefOrPoison(CmpRHS, SQ.AC, &Sel, &DT)) {
-    if (Value *V = simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ,
-                                          /* AllowRefinement */ true))
-      // Require either the replacement or the simplification result to be a
-      // constant to avoid infinite loops.
-      // FIXME: Make this check more precise.
-      if (isa<Constant>(CmpRHS) || isa<Constant>(V))
-        return replaceOperand(Sel, Swapped ? 2 : 1, V);
+  auto ReplaceLHSOpWithRHSOp = [&](Value *OldOp,
+                                   Value *NewOp) -> Instruction * {
+    // In X == Y ? f(X) : Z, try to evaluate f(Y) and replace the operand.
+    // Take care to avoid replacing X == Y ? X : Z with X == Y ? Y : Z, as that
+    // would lead to an infinite replacement cycle.
+    // If we will be able to evaluate f(Y) to a constant, we can allow undef,
+    // otherwise Y cannot be undef as we might pick different values for undef
+    // in the icmp and in f(Y).
+    if (TrueVal == OldOp)
+      return nullptr;
+
+    // The noundef query is relatively expensive, so run it at most once and
+    // cache the answer. It is only ever asked about NewOp, hence no parameter.
+    std::optional<bool> IsNeverUndefCached;
+    auto IsNeverUndef = [&]() {
+      if (!IsNeverUndefCached.has_value())
+        IsNeverUndefCached =
+            isGuaranteedNotToBeUndefOrPoison(NewOp, SQ.AC, &Sel, &DT);
+      return *IsNeverUndefCached;
+    };
+    if (Value *V = simplifyWithOpReplaced(TrueVal, OldOp, NewOp, SQ,
+                                          /* AllowRefinement= */ true)) {
+      // Need some guarantees about the new simplified op to ensure we don't
+      // loop infinitely.
+      // If we simplify to a constant, replace.
+      bool ShouldReplace = match(V, m_ImmConstant());
+      bool NeedsNoUndef = !ShouldReplace;
+      // Or replace if NewOp is a constant.
+      if (!ShouldReplace && match(NewOp, m_ImmConstant()))
+        ShouldReplace = true;
+      // Or if we end up simplifying f(Y) -> Y, i.e. Old & New -> New & New ->
+      // New.
+      if (!ShouldReplace && V == NewOp)
+        ShouldReplace = true;
+
+      // Finally, if we are going to create a new one-use instruction, replace.
+      if (!ShouldReplace && isa<Instruction>(OldOp) && OldOp->hasNUses(2) &&
+          (!isa<Instruction>(NewOp) || !NewOp->hasOneUse()))
+        ShouldReplace = true;
+
+      // Unless we simplify the new instruction to a constant, need to ensure Y
+      // is not undef.
+      if (NeedsNoUndef && ShouldReplace)
+        ShouldReplace = IsNeverUndef();
+
+      if (ShouldReplace)
+        return replaceOperand(Sel, Swapped ? 2 : 1, V);
+    }
+    // If we can't simplify, but we will either:
+    // 1) Create a new binop where both ops are NewOp i.e (add x, y) is "worse"
+    //    than (add y, y) in this case, wait until the second call so we don't
+    //    miss a one-use simplification.
+    // 2) Create a new one-use instruction.
+    // proceed.
+    if (TrueVal->hasOneUse() &&
+        (match(TrueVal, m_c_BinOp(m_Specific(OldOp), m_Specific(NewOp))) ||
+         (isa<Instruction>(TrueVal) && isa<Instruction>(OldOp) &&
+          OldOp->hasNUses(2) &&
+          (!isa<Instruction>(NewOp) || !NewOp->hasOneUse())))) {
+      auto *TrueIns = cast<Instruction>(TrueVal);
+      for (unsigned OpIdx = 0; OpIdx < TrueIns->getNumOperands(); ++OpIdx) {
+        if (TrueIns->getOperand(OpIdx) != OldOp)
+          continue;
+        // TrueIns is rewritten in place and still executes unconditionally, so
+        // it will see NewOp even when the compare is false. That is only valid
+        // if it is safe to speculatively execute (no side effects or undefined
+        // behavior with different operands) and NewOp dominates it.
+        if (!isSafeToSpeculativelyExecute(TrueIns) ||
+            !DT.dominates(NewOp, TrueIns))
+          break;
+        // Need to ensure NewOp is noundef (same reason as above). Wait until
+        // the last moment to do this check as it can be relatively expensive.
+        if (!IsNeverUndef())
+          break;
+        // Update the operand via replaceOperand(), as is done for Sel, rather
+        // than a raw setOperand() that bypasses InstCombine's bookkeeping.
+        replaceOperand(*TrueIns, OpIdx, NewOp);
+        return replaceOperand(Sel, Swapped ? 2 : 1, TrueIns);
+      }
+    }
 
     // Even if TrueVal does not simplify, we can directly replace a use of
     // CmpLHS with CmpRHS, as long as the instruction is not used anywhere
     // else and is safe to speculatively execute (we may end up executing it
@@ -1308,17 +1374,19 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
     // undefined behavior). Only do this if CmpRHS is a constant, as
     // profitability is not clear for other cases.
     // FIXME: Support vectors.
-    if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant()) &&
-        !Cmp.getType()->isVectorTy())
-      if (replaceInInstruction(TrueVal, CmpLHS, CmpRHS))
+    if (OldOp == CmpLHS && match(NewOp, m_ImmConstant()) &&
+        !match(OldOp, m_ImmConstant()) && !Cmp.getType()->isVectorTy() &&
+        IsNeverUndef())
+      if (replaceInInstruction(TrueVal, OldOp, NewOp))
         return &Sel;
-  }
-  if (TrueVal != CmpRHS &&
-      isGuaranteedNotToBeUndefOrPoison(CmpLHS, SQ.AC, &Sel, &DT))
-    if (Value *V = simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, SQ,
-                                          /* AllowRefinement */ true))
-      if (isa<Constant>(CmpLHS) || isa<Constant>(V))
-        return replaceOperand(Sel, Swapped ? 2 : 1, V);
+
+    return nullptr;
+  };
+
+  if (Instruction *R = ReplaceLHSOpWithRHSOp(CmpLHS, CmpRHS))
+    return R;
+  if (Instruction *R = ReplaceLHSOpWithRHSOp(CmpRHS, CmpLHS))
+    return R;
 
   auto *FalseInst = dyn_cast<Instruction>(FalseVal);
   if (!FalseInst)
diff --git a/llvm/test/Transforms/InstCombine/abs-1.ll b/llvm/test/Transforms/InstCombine/abs-1.ll
index 32bd7a37053ed..0cf7cd97d8ff4 100644
--- a/llvm/test/Transforms/InstCombine/abs-1.ll
+++ b/llvm/test/Transforms/InstCombine/abs-1.ll
@@ -852,11 +852,8 @@ define i8 @abs_diff_signed_sgt_nuw_extra_use3(i8 %a, i8 %b) {
 
 define i32 @abs_diff_signed_slt_swap_wrong_pred1(i32 %a, i32 %b) {
 ; CHECK-LABEL: @abs_diff_signed_slt_swap_wrong_pred1(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT:    [[SUB_BA:%.*]] = sub nsw i32 [[B]], [[A]]
-; CHECK-NEXT:    [[SUB_AB:%.*]] = sub nsw i32 [[A]], [[B]]
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i32 [[SUB_BA]], i32 [[SUB_AB]]
-; CHECK-NEXT:    ret i32 [[COND]]
+; CHECK-NEXT:    [[SUB_AB:%.*]] = sub nsw i32 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i32 [[SUB_AB]]
 ;
   %cmp = icmp eq i32 %a, %b
   %sub_ba = sub nsw i32 %b, %a
diff --git a/llvm/test/Transforms/InstCombine/select-cmp-eq-op-fold.ll b/llvm/test/Transforms/InstCombine/select-cmp-eq-op-fold.ll
index ec82b1944f723..d2a5c71025370 100644
--- a/llvm/test/Transforms/InstCombine/select-cmp-eq-op-fold.ll
+++ b/llvm/test/Transforms/InstCombine/select-cmp-eq-op-fold.ll
@@ -6,8 +6,7 @@ declare void @use.i8(i8)
 define i8 @replace_with_y_noundef(i8 %x, i8 noundef %y, i8 %z) {
 ; CHECK-LABEL: @replace_with_y_noundef(
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[X]], [[Y]]
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[CMP]], i8 [[AND]], i8 [[Z:%.*]]
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[CMP]], i8 [[Y]], i8 [[Z:%.*]]
 ; CHECK-NEXT:    ret i8 [[SEL]]
 ;
   %cmp = icmp eq i8 %x, %y
@@ -20,8 +19,7 @@ define i8 @replace_with_x_noundef(i8 noundef %x, i8 %y, i8 %z) {
 ; CHECK-LABEL:
@replace_with_x_noundef( ; CHECK-NEXT: [[CMP:%.*]] = icmp ne i8 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: call void @use.i1(i1 [[CMP]]) -; CHECK-NEXT: [[AND:%.*]] = or i8 [[X]], [[Y]] -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[Z:%.*]], i8 [[AND]] +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[Z:%.*]], i8 [[X]] ; CHECK-NEXT: ret i8 [[SEL]] ; %cmp = icmp ne i8 %x, %y @@ -50,7 +48,7 @@ define i8 @replace_with_y_for_new_oneuse(i8 noundef %xx, i8 noundef %y, i8 %z) { ; CHECK-LABEL: @replace_with_y_for_new_oneuse( ; CHECK-NEXT: [[X:%.*]] = mul i8 [[XX:%.*]], 13 ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X]], [[Y:%.*]] -; CHECK-NEXT: [[ADD:%.*]] = add nuw i8 [[X]], [[Y]] +; CHECK-NEXT: [[ADD:%.*]] = shl nuw i8 [[Y]], 1 ; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[ADD]], i8 [[Z:%.*]] ; CHECK-NEXT: ret i8 [[SEL]] ; @@ -65,7 +63,7 @@ define i8 @replace_with_y_for_new_oneuse2(i8 %xx, i8 noundef %y, i8 %z, i8 %q) { ; CHECK-LABEL: @replace_with_y_for_new_oneuse2( ; CHECK-NEXT: [[X:%.*]] = mul i8 [[XX:%.*]], 13 ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X]], [[Y:%.*]] -; CHECK-NEXT: [[ADD:%.*]] = add nuw i8 [[X]], [[Q:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = add nuw i8 [[Y]], [[Q:%.*]] ; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[ADD]], i8 [[Z:%.*]] ; CHECK-NEXT: ret i8 [[SEL]] ; @@ -81,7 +79,7 @@ define i8 @replace_with_x_for_new_oneuse(i8 noundef %xx, i8 noundef %yy, i8 %z, ; CHECK-NEXT: [[X:%.*]] = mul i8 [[XX:%.*]], 13 ; CHECK-NEXT: [[Y:%.*]] = add i8 [[YY:%.*]], [[W:%.*]] ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X]], [[Y]] -; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[X]], [[Y]] +; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[X]], [[X]] ; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[MUL]], i8 [[Z:%.*]] ; CHECK-NEXT: ret i8 [[SEL]] ; @@ -115,7 +113,7 @@ define i8 @replace_with_x_for_simple_binop(i8 noundef %xx, i8 %yy, i8 %z, i8 %w) ; CHECK-NEXT: [[X:%.*]] = mul i8 [[XX:%.*]], 13 ; CHECK-NEXT: [[Y:%.*]] = add i8 [[YY:%.*]], [[W:%.*]] ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 
[[X]], [[Y]] -; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[X]], [[Y]] +; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[X]], [[X]] ; CHECK-NEXT: call void @use.i8(i8 [[Y]]) ; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[MUL]], i8 [[Z:%.*]] ; CHECK-NEXT: ret i8 [[SEL]] @@ -147,7 +145,7 @@ define i8 @replace_with_none_for_new_oneuse_fail_maybe_undef(i8 %xx, i8 %y, i8 % define i8 @replace_with_y_for_simple_binop(i8 %x, i8 noundef %y, i8 %z) { ; CHECK-LABEL: @replace_with_y_for_simple_binop( ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[MUL:%.*]] = mul nsw i8 [[X]], [[Y]] +; CHECK-NEXT: [[MUL:%.*]] = mul nsw i8 [[Y]], [[Y]] ; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[MUL]], i8 [[Z:%.*]] ; CHECK-NEXT: ret i8 [[SEL]] ; diff --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll index 2ade6faa99be3..bfb3fad1050ba 100644 --- a/llvm/test/Transforms/InstCombine/select.ll +++ b/llvm/test/Transforms/InstCombine/select.ll @@ -2787,8 +2787,7 @@ define <2 x i8> @select_replacement_add_eq_vec_nonuniform(<2 x i8> %x, <2 x i8> define <2 x i8> @select_replacement_add_eq_vec_poison(<2 x i8> %x, <2 x i8> %y) { ; CHECK-LABEL: @select_replacement_add_eq_vec_poison( ; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i8> [[X:%.*]], -; CHECK-NEXT: [[ADD:%.*]] = add <2 x i8> [[X]], -; CHECK-NEXT: [[SEL:%.*]] = select <2 x i1> [[CMP]], <2 x i8> [[ADD]], <2 x i8> [[Y:%.*]] +; CHECK-NEXT: [[SEL:%.*]] = select <2 x i1> [[CMP]], <2 x i8> , <2 x i8> [[Y:%.*]] ; CHECK-NEXT: ret <2 x i8> [[SEL]] ; %cmp = icmp eq <2 x i8> %x, @@ -2839,8 +2838,7 @@ define i8 @select_replacement_sub_noundef(i8 %x, i8 noundef %y, i8 %z) { define i8 @select_replacement_sub(i8 %x, i8 %y, i8 %z) { ; CHECK-LABEL: @select_replacement_sub( ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[SUB:%.*]] = sub i8 [[X]], [[Y]] -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[SUB]], i8 [[Z:%.*]] +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 0, 
i8 [[Z:%.*]] ; CHECK-NEXT: ret i8 [[SEL]] ; %cmp = icmp eq i8 %x, %y