Skip to content

[InstCombine] Add a helper that simplifies an instruction compared eq/ne against a set of constants #86346

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,24 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
PatternMatch::m_Value()));
}

/// Assumes we are matching `Op eq/ne Vals` (from either an icmp or a switch).
/// Tries to constant-fold the comparison into `Op' eq/ne Vals'`, where `Op'`
/// is an operand of `Op`. For example, if `Op` is `add X, C0`, it rewrites
/// each `Vals[i]` in place as `Vals[i] - C0` and returns `X`.
Value *simplifyOpWithConstantEqConsts(Value *Op, BuilderTy &Builder,
SmallVector<Constant *> &Vals,
bool ReqOneUseAdd = true);

/// Single-constant convenience overload: wraps \p Val in a one-element
/// vector, forwards to the vector form, and writes the (possibly folded)
/// constant back into \p Val. Returns the simplified operand, or nullptr
/// if no simplification was possible (\p Val is then unchanged).
Value *simplifyOpWithConstantEqConsts(Value *Op, BuilderTy &Builder,
                                      Constant *&Val,
                                      bool ReqOneUseAdd = true) {
  SmallVector<Constant *> Singleton{Val};
  Value *Simplified =
      simplifyOpWithConstantEqConsts(Op, Builder, Singleton, ReqOneUseAdd);
  Val = Singleton.front();
  return Simplified;
}

/// Return nonnull value if V is free to invert under the condition of
/// WillInvertAllUses.
/// If Builder is nonnull, it will return a simplified ~V.
Expand Down
32 changes: 8 additions & 24 deletions llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3572,16 +3572,6 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant(
return new ICmpInst(Pred, II->getArgOperand(0), ConstantInt::get(Ty, C));
break;

case Intrinsic::bswap:
// bswap(A) == C -> A == bswap(C)
return new ICmpInst(Pred, II->getArgOperand(0),
ConstantInt::get(Ty, C.byteSwap()));

case Intrinsic::bitreverse:
// bitreverse(A) == C -> A == bitreverse(C)
return new ICmpInst(Pred, II->getArgOperand(0),
ConstantInt::get(Ty, C.reverseBits()));

case Intrinsic::ctlz:
case Intrinsic::cttz: {
// ctz(A) == bitwidth(A) -> A == 0 and likewise for !=
Expand Down Expand Up @@ -3618,20 +3608,6 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant(
break;
}

case Intrinsic::fshl:
case Intrinsic::fshr:
if (II->getArgOperand(0) == II->getArgOperand(1)) {
const APInt *RotAmtC;
// ror(X, RotAmtC) == C --> X == rol(C, RotAmtC)
// rol(X, RotAmtC) == C --> X == ror(C, RotAmtC)
if (match(II->getArgOperand(2), m_APInt(RotAmtC)))
return new ICmpInst(Pred, II->getArgOperand(0),
II->getIntrinsicID() == Intrinsic::fshl
? ConstantInt::get(Ty, C.rotr(*RotAmtC))
: ConstantInt::get(Ty, C.rotl(*RotAmtC)));
}
break;

case Intrinsic::umax:
case Intrinsic::uadd_sat: {
// uadd.sat(a, b) == 0 -> (a | b) == 0
Expand Down Expand Up @@ -5456,6 +5432,14 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {

Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
const CmpInst::Predicate Pred = I.getPredicate();
{
Constant *C;
if (match(Op1, m_ImmConstant(C))) {
if (auto *R = simplifyOpWithConstantEqConsts(Op0, Builder, C))
return new ICmpInst(Pred, R, C);
}
}

Value *A, *B, *C, *D;
if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) {
if (A == Op1 || B == Op1) { // (A^B) == A -> B == 0
Expand Down
290 changes: 230 additions & 60 deletions llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3572,78 +3572,248 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
return nullptr;
}

Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
Value *Cond = SI.getCondition();
Value *Op0;
ConstantInt *AddRHS;
if (match(Cond, m_Add(m_Value(Op0), m_ConstantInt(AddRHS)))) {
// Change 'switch (X+4) case 1:' into 'switch (X) case -3'.
for (auto Case : SI.cases()) {
Constant *NewCase = ConstantExpr::getSub(Case.getCaseValue(), AddRHS);
assert(isa<ConstantInt>(NewCase) &&
"Result of expression should be constant");
Case.setValue(cast<ConstantInt>(NewCase));
Value *
InstCombiner::simplifyOpWithConstantEqConsts(Value *Op, BuilderTy &Builder,
                                             SmallVector<Constant *> &Vals,
                                             bool ReqOneUseAdd) {
  // Only instructions / constant expressions can be peeled apart.
  Operator *I = dyn_cast<Operator>(Op);
  if (!I)
    return nullptr;

  // Constant-fold the inverse of `I` onto every RHS constant in Vals.
  auto ReverseAll = [&](function_ref<Constant *(Constant *)> ReverseF) {
    for (size_t i = 0, e = Vals.size(); i < e; ++i)
      Vals[i] = ReverseF(Vals[i]);
  };

  // Pre-extract the APInt form of each constant. Several transforms below
  // are only valid when every element of Vals is a plain integer (or splat)
  // constant.
  SmallVector<const APInt *, 4> ValsAsAPInt;
  for (Constant *C : Vals) {
    const APInt *CAPInt;
    if (!match(C, m_APInt(CAPInt)))
      break;
    ValsAsAPInt.push_back(CAPInt);
  }
  bool UseAPInt = ValsAsAPInt.size() == Vals.size();

  // Same as ReverseAll, but expressed on the APInt views; only legal when
  // UseAPInt is true.
  auto ReverseAllAPInt = [&](function_ref<APInt(const APInt *)> ReverseF) {
    assert(UseAPInt && "Can't reverse non-apint constants!");
    for (size_t i = 0, e = Vals.size(); i < e; ++i)
      Vals[i] = ConstantInt::get(Vals[i]->getType(), ReverseF(ValsAsAPInt[i]));
  };

  Constant *C;
  switch (I->getOpcode()) {
  default:
    break;
  case Instruction::Or:
    if (!match(I, m_DisjointOr(m_Value(), m_Value())))
      break;
    // Can treat `or disjoint` as add.
    [[fallthrough]];
  case Instruction::Add:
    // We get some regressions if we drop the OneUse for add in some cases.
    // See discussion in D58633.
    if (ReqOneUseAdd && !I->hasOneUse())
      break;
    if (!match(I->getOperand(1), m_ImmConstant(C)))
      break;
    // X + C0 == C1 -> X == C1 - C0
    ReverseAll([&](Constant *Val) { return ConstantExpr::getSub(Val, C); });
    return I->getOperand(0);
  case Instruction::Sub:
    if (!match(I->getOperand(0), m_ImmConstant(C)))
      break;
    // C0 - X == C1 -> X == C0 - C1
    ReverseAll([&](Constant *Val) { return ConstantExpr::getSub(C, Val); });
    return I->getOperand(1);
  case Instruction::Xor:
    if (!match(I->getOperand(1), m_ImmConstant(C)))
      break;
    // X ^ C0 == C1 -> X == C1 ^ C0
    ReverseAll([&](Constant *Val) { return ConstantExpr::getXor(Val, C); });
    return I->getOperand(0);
  case Instruction::Mul: {
    const APInt *MC;
    if (!UseAPInt || !match(I->getOperand(1), m_APInt(MC)) || MC->isZero())
      break;
    auto *Mul = cast<OverflowingBinaryOperator>(I);
    if (!Mul->hasNoUnsignedWrap())
      break;

    // X *nuw C0 == C1 -> X == C1 u/ C0 iff C1 u% C0 == 0
    if (all_of(ValsAsAPInt,
               [&](const APInt *AC) { return AC->urem(*MC).isZero(); })) {
      ReverseAllAPInt([&](const APInt *Val) { return Val->udiv(*MC); });
      return I->getOperand(0);
    }

    // X *nuw C0 == C1 -> X == C1 s/ C0 iff C1 s% C0 == 0
    // Guard against the INT_MIN / -1 sdiv overflow case.
    if (all_of(ValsAsAPInt, [&](const APInt *AC) {
          return (!AC->isMinSignedValue() || !MC->isAllOnes()) &&
                 AC->srem(*MC).isZero();
        })) {
      ReverseAllAPInt([&](const APInt *Val) { return Val->sdiv(*MC); });
      return I->getOperand(0);
    }
    break;
  }
  case Instruction::UDiv:
  case Instruction::SDiv: {
    const APInt *DC;
    if (!UseAPInt || !match(I->getOperand(1), m_APInt(DC)))
      break;
    if (!cast<PossiblyExactOperator>(Op)->isExact())
      break;
    // X u/exact C0 == C1 -> X == C0 * C1 iff C0 * C1 is nuw
    // X s/exact C0 == C1 -> X == C0 * C1 iff C0 * C1 is nsw
    if (!all_of(ValsAsAPInt, [&](const APInt *AC) {
          bool Ov;
          (void)(I->getOpcode() == Instruction::UDiv ? DC->umul_ov(*AC, Ov)
                                                     : DC->smul_ov(*AC, Ov));
          return !Ov;
        }))
      break;

    ReverseAllAPInt([&](const APInt *Val) { return *Val * *DC; });
    return I->getOperand(0);
  }
  case Instruction::ZExt:
  case Instruction::SExt: {
    if (!UseAPInt)
      break;
    bool IsZExt = isa<ZExtInst>(I);
    Type *SrcTy = I->getOperand(0)->getType();
    unsigned NewWidth = SrcTy->getScalarSizeInBits();
    // zext(X) == C1 -> X == trunc C1 iff zext(trunc(C1)) == C1
    // sext(X) == C1 -> X == trunc C1 iff sext(trunc(C1)) == C1
    if (!all_of(ValsAsAPInt, [&](const APInt *AC) {
          return IsZExt ? AC->isIntN(NewWidth) : AC->isSignedIntN(NewWidth);
        }))
      break;

    for (size_t i = 0, e = Vals.size(); i < e; ++i)
      Vals[i] = ConstantInt::get(SrcTy, ValsAsAPInt[i]->trunc(NewWidth));
    return I->getOperand(0);
  }
  case Instruction::Shl:
  case Instruction::LShr:
  case Instruction::AShr: {
    if (!UseAPInt)
      break;
    uint64_t ShAmtC;
    if (!match(I->getOperand(1), m_ConstantInt(ShAmtC)))
      break;
    if (ShAmtC >= I->getType()->getScalarSizeInBits())
      break;

    // X << C0 == C1 -> X == C1 >> C0 iff C1 >> C0 is exact
    // X u>> C0 == C1 -> X == C1 << C0 iff C1 << C0 is nuw
    // X s>> C0 == C1 -> X == C1 << C0 iff C1 << C0 is nsw
    if (!all_of(ValsAsAPInt, [&](const APInt *AC) {
          switch (I->getOpcode()) {
          case Instruction::Shl:
            return AC->countr_zero() >= ShAmtC;
          case Instruction::LShr:
            return AC->countl_zero() >= ShAmtC;
          case Instruction::AShr:
            // NOTE(review): rebuilding C1 with an nsw shl needs strictly more
            // sign bits than the shift amount; confirm `>=` vs `>` here.
            return AC->getNumSignBits() >= ShAmtC;
          default:
            llvm_unreachable("Already checked Opcode");
          }
        }))
      break;

    bool HasExact = false, HasNUW = false, HasNSW = false;
    if (I->getOpcode() == Instruction::Shl) {
      auto *Shl = cast<OverflowingBinaryOperator>(I);
      HasNUW = Shl->hasNoUnsignedWrap();
      HasNSW = Shl->hasNoSignedWrap();
    } else {
      HasExact = cast<PossiblyExactOperator>(Op)->isExact();
    }

    Value *R = I->getOperand(0);
    if (!HasExact && !HasNUW && !HasNSW) {
      if (!I->hasOneUse())
        break;

      // We may be shifting out 1s from X, so need to mask it.
      unsigned BitWidth = R->getType()->getScalarSizeInBits();
      R = Builder.CreateAnd(
          R, I->getOpcode() == Instruction::Shl
                 ? APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC)
                 : APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC));
    }

    ReverseAllAPInt([&](const APInt *Val) {
      if (I->getOpcode() == Instruction::Shl)
        return HasNSW ? Val->ashr(ShAmtC) : Val->lshr(ShAmtC);
      return Val->shl(ShAmtC);
    });
    return R;
  }
  case Instruction::Call:
  case Instruction::Invoke: {
    auto *II = dyn_cast<IntrinsicInst>(I);
    if (!II)
      break;
    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::bitreverse:
      if (!UseAPInt)
        break;
      // bitreverse(X) == C -> X == bitreverse(C)
      ReverseAllAPInt([&](const APInt *Val) { return Val->reverseBits(); });
      return II->getArgOperand(0);
    case Intrinsic::bswap:
      if (!UseAPInt)
        break;
      // bswap(X) == C -> X == bswap(C)
      ReverseAllAPInt([&](const APInt *Val) { return Val->byteSwap(); });
      return II->getArgOperand(0);
    case Intrinsic::fshr:
    case Intrinsic::fshl: {
      if (!UseAPInt)
        break;
      // Only rotates (both funnel inputs identical) are invertible.
      if (II->getArgOperand(0) != II->getArgOperand(1))
        break;
      const APInt *RotAmtC;
      if (!match(II->getArgOperand(2), m_APInt(RotAmtC)))
        break;
      // rol(X, C0) == C1 -> X == ror(C1, C0)
      // ror(X, C0) == C1 -> X == rol(C1, C0)
      ReverseAllAPInt([&](const APInt *Val) {
        return II->getIntrinsicID() == Intrinsic::fshl ? Val->rotr(*RotAmtC)
                                                       : Val->rotl(*RotAmtC);
      });
      return II->getArgOperand(0);
    }
    }
    break;
  }
  }
  return nullptr;
}

Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
Value *Cond = SI.getCondition();

SmallVector<Constant *> CaseVals;
for (const auto &Case : SI.cases())
CaseVals.push_back(Case.getCaseValue());

if (auto *R = simplifyOpWithConstantEqConsts(Cond, Builder, CaseVals,
/*ReqOneUseAdd=*/false)) {
unsigned i = 0;
for (auto &Case : SI.cases())
Case.setValue(cast<ConstantInt>(CaseVals[i++]));
return replaceOperand(SI, 0, R);
}

KnownBits Known = computeKnownBits(Cond, 0, &SI);
unsigned LeadingKnownZeros = Known.countMinLeadingZeros();
Expand Down
Loading