Skip to content

Commit

Permalink
[InstCombine] Use simplifyWithOpReplaced() for non-bool selects
Browse files Browse the repository at this point in the history
Perform the simplifyWithOpReplaced() fold even for non-bool
selects. This subsumes a number of recently added folds for
zext/sext of the condition.

We still need to manually handle variations with both sext/zext
and not, because simplifyWithOpReplaced() only performs one
level of replacements.
  • Loading branch information
nikic committed Sep 22, 2022
1 parent babdef2 commit c2e76f9
Showing 1 changed file with 21 additions and 34 deletions.
55 changes: 21 additions & 34 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2664,46 +2664,40 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
// If the type of select is not an integer type or if the condition and
// the selection type are not both scalar nor both vector types, there is no
// point in attempting to match these patterns.
Type *CondType = CondVal->getType();
if (!isa<Constant>(CondVal) && SelType->isIntOrIntVectorTy() &&
CondVal->getType()->isVectorTy() == SelType->isVectorTy()) {
auto *One = ConstantInt::get(SelType, 1);
auto *Zero = ConstantInt::get(SelType, 0);
auto *AllOnes = ConstantInt::get(SelType, -1, /*isSigned*/ true);

// select a, a, b -> select a, 1, b
// select a, zext(a), b -> select a, 1, b
if (match(TrueVal, m_ZExtOrSelf(m_Specific(CondVal))))
return replaceOperand(SI, 1, One);

// select a, sext(a), b -> select a, -1, b
if (match(TrueVal, m_SExt(m_Specific(CondVal))))
return replaceOperand(SI, 1, AllOnes);
CondType->isVectorTy() == SelType->isVectorTy()) {
if (Value *S = simplifyWithOpReplaced(TrueVal, CondVal,
ConstantInt::getTrue(CondType), SQ,
/* AllowRefinement */ true))
return replaceOperand(SI, 1, S);

// select a, b, a -> select a, b, 0
// select a, b, zext(a) -> select a, b, 0
// select a, b, sext(a) -> select a, b, 0
if (match(FalseVal, m_ZExtOrSExtOrSelf(m_Specific(CondVal))))
return replaceOperand(SI, 2, Zero);
if (Value *S = simplifyWithOpReplaced(FalseVal, CondVal,
ConstantInt::getFalse(CondType), SQ,
/* AllowRefinement */ true))
return replaceOperand(SI, 2, S);

// Handle patterns involving sext/zext + not explicitly,
// as simplifyWithOpReplaced() only looks past one instruction.
Value *NotCond;

// select a, !a, b -> select !a, b, 0
// select a, sext(!a), b -> select !a, b, 0
// select a, zext(!a), b -> select !a, b, 0
if (match(TrueVal, m_ZExtOrSExtOrSelf(m_CombineAnd(
m_Value(NotCond), m_Not(m_Specific(CondVal))))))
return SelectInst::Create(NotCond, FalseVal, Zero);
if (match(TrueVal, m_ZExtOrSExt(m_CombineAnd(m_Value(NotCond),
m_Not(m_Specific(CondVal))))))
return SelectInst::Create(NotCond, FalseVal,
Constant::getNullValue(SelType));

// select a, b, !a -> select !a, 1, b
// select a, b, zext(!a) -> select !a, 1, b
if (match(FalseVal, m_ZExtOrSelf(m_CombineAnd(m_Value(NotCond),
m_Not(m_Specific(CondVal))))))
return SelectInst::Create(NotCond, One, TrueVal);
if (match(FalseVal, m_ZExt(m_CombineAnd(m_Value(NotCond),
m_Not(m_Specific(CondVal))))))
return SelectInst::Create(NotCond, ConstantInt::get(SelType, 1), TrueVal);

// select a, b, sext(!a) -> select !a, -1, b
if (match(FalseVal, m_SExt(m_CombineAnd(m_Value(NotCond),
m_Not(m_Specific(CondVal))))))
return SelectInst::Create(NotCond, AllOnes, TrueVal);
return SelectInst::Create(NotCond, Constant::getAllOnesValue(SelType),
TrueVal);
}

// Avoid potential infinite loops by checking for non-constant condition.
Expand Down Expand Up @@ -2794,13 +2788,6 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
return BinaryOperator::CreateAnd(FalseVal, Builder.CreateOr(C, TrueVal));
}

if (Value *S = simplifyWithOpReplaced(TrueVal, CondVal, One, SQ,
/* AllowRefinement */ true))
return replaceOperand(SI, 1, S);
if (Value *S = simplifyWithOpReplaced(FalseVal, CondVal, Zero, SQ,
/* AllowRefinement */ true))
return replaceOperand(SI, 2, S);

if (match(FalseVal, m_Zero()) || match(TrueVal, m_One())) {
Use *Y = nullptr;
bool IsAnd = match(FalseVal, m_Zero()) ? true : false;
Expand Down

0 comments on commit c2e76f9

Please sign in to comment.