diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index a908349eaff141..b6d0bed808d3f4 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -1033,9 +1033,14 @@ class Negator final { using BuilderTy = IRBuilder; BuilderTy Builder; + const DataLayout &DL; + AssumptionCache &AC; + const DominatorTree &DT; + const bool IsTrulyNegation; - Negator(LLVMContext &C, const DataLayout &DL, bool IsTrulyNegation); + Negator(LLVMContext &C, const DataLayout &DL, AssumptionCache &AC, + const DominatorTree &DT, bool IsTrulyNegation); #if LLVM_ENABLE_STATS unsigned NumValuesVisitedInThisNegator = 0; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp index 42bb748cc28720..c393a6373f7ac7 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp @@ -46,6 +46,13 @@ #include #include +namespace llvm { +class AssumptionCache; +class DataLayout; +class DominatorTree; +class LLVMContext; +} // namespace llvm + using namespace llvm; #define DEBUG_TYPE "instcombine" @@ -87,13 +94,14 @@ static cl::opt cl::desc("What is the maximal lookup depth when trying to " "check for viability of negation sinking.")); -Negator::Negator(LLVMContext &C, const DataLayout &DL, bool IsTrulyNegation_) - : Builder(C, TargetFolder(DL), +Negator::Negator(LLVMContext &C, const DataLayout &DL_, AssumptionCache &AC_, + const DominatorTree &DT_, bool IsTrulyNegation_) + : Builder(C, TargetFolder(DL_), IRBuilderCallbackInserter([&](Instruction *I) { ++NegatorNumInstructionsCreatedTotal; NewInstructions.push_back(I); })), - IsTrulyNegation(IsTrulyNegation_) {} + DL(DL_), AC(AC_), DT(DT_), IsTrulyNegation(IsTrulyNegation_) {} #if LLVM_ENABLE_STATS Negator::~Negator() { @@ -301,6 +309,16 @@ LLVM_NODISCARD Value *Negator::visit(Value *V, unsigned Depth) { return nullptr; return Builder.CreateShl(NegOp0, I->getOperand(1), I->getName() + ".neg"); } + case Instruction::Or: + if (!haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), DL, &AC, I, + &DT)) + return nullptr; // Don't know how to handle `or` in general. + // `or`/`add` are interchangeable when operands have no common bits set. + // `inc` is always negatible. + if (match(I->getOperand(1), m_One())) + return Builder.CreateNot(I->getOperand(0), I->getName() + ".neg"); + // Else, just defer to Instruction::Add handling. + LLVM_FALLTHROUGH; case Instruction::Add: { // `add` is negatible if both of its operands are negatible. Value *NegOp0 = visit(I->getOperand(0), Depth + 1); @@ -364,7 +382,8 @@ LLVM_NODISCARD Value *Negator::Negate(bool LHSIsZero, Value *Root, if (!NegatorEnabled || !DebugCounter::shouldExecute(NegatorCounter)) return nullptr; - Negator N(Root->getContext(), IC.getDataLayout(), LHSIsZero); + Negator N(Root->getContext(), IC.getDataLayout(), IC.getAssumptionCache(), + IC.getDominatorTree(), LHSIsZero); Optional Res = N.run(Root); if (!Res) { // Negation failed. LLVM_DEBUG(dbgs() << "Negator: failed to sink negation into " << *Root diff --git a/llvm/test/Transforms/InstCombine/sub-of-negatible.ll b/llvm/test/Transforms/InstCombine/sub-of-negatible.ll index e22274be380bd1..0f2e6336a73e4e 100644 --- a/llvm/test/Transforms/InstCombine/sub-of-negatible.ll +++ b/llvm/test/Transforms/InstCombine/sub-of-negatible.ll @@ -827,8 +827,8 @@ nonneg_bb: define i8 @negation_of_increment_via_or_with_no_common_bits_set(i8 %x, i8 %y) { ; CHECK-LABEL: @negation_of_increment_via_or_with_no_common_bits_set( ; CHECK-NEXT: [[T0:%.*]] = shl i8 [[Y:%.*]], 1 -; CHECK-NEXT: [[T1:%.*]] = or i8 [[T0]], 1 -; CHECK-NEXT: [[T2:%.*]] = sub i8 [[X:%.*]], [[T1]] +; CHECK-NEXT: [[T1_NEG:%.*]] = xor i8 [[T0]], -1 +; CHECK-NEXT: [[T2:%.*]] = add i8 [[T1_NEG]], [[X:%.*]] ; CHECK-NEXT: ret i8 [[T2]] ; %t0 = shl i8 %y, 1 @@ -868,9 +868,9 @@ define i8 @add_via_or_with_no_common_bits_set(i8 %x, i8 %y) { ; CHECK-LABEL: @add_via_or_with_no_common_bits_set( ; CHECK-NEXT: [[T0:%.*]] = sub i8 0, [[Y:%.*]] ; CHECK-NEXT: call void @use8(i8 [[T0]]) -; CHECK-NEXT: [[T1:%.*]] = shl i8 [[T0]], 2 -; CHECK-NEXT: [[T2:%.*]] = or i8 [[T1]], 3 -; CHECK-NEXT: [[T3:%.*]] = sub i8 [[X:%.*]], [[T2]] +; CHECK-NEXT: [[T1_NEG:%.*]] = shl i8 [[Y]], 2 +; CHECK-NEXT: [[T2_NEG:%.*]] = add i8 [[T1_NEG]], -3 +; CHECK-NEXT: [[T3:%.*]] = add i8 [[T2_NEG]], [[X:%.*]] ; CHECK-NEXT: ret i8 [[T3]] ; %t0 = sub i8 0, %y