diff --git a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp index 284312eaf82207..fb7a005708e56d 100644 --- a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp +++ b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -607,6 +607,12 @@ static bool isNonNegative(Value *V, LazyValueInfo *LVI, Instruction *CxtI) { return Result == LazyValueInfo::True; } +static bool isNonPositive(Value *V, LazyValueInfo *LVI, Instruction *CxtI) { + Constant *Zero = ConstantInt::get(V->getType(), 0); + auto Result = LVI->getPredicateAt(ICmpInst::ICMP_SLE, V, Zero, CxtI); + return Result == LazyValueInfo::True; +} + static bool allOperandsAreNonNegative(BinaryOperator *SDI, LazyValueInfo *LVI) { return all_of(SDI->operands(), [&](Value *Op) { return isNonNegative(Op, LVI, SDI); }); @@ -672,24 +678,65 @@ static bool processSRem(BinaryOperator *SDI, LazyValueInfo *LVI) { } /// See if LazyValueInfo's ability to exploit edge conditions or range -/// information is sufficient to prove the both operands of this SDiv are -/// positive. If this is the case, replace the SDiv with a UDiv. Even for local +/// information is sufficient to prove the signs of both operands of this SDiv. +/// If this is the case, replace the SDiv with a UDiv. Even for local /// conditions, this can sometimes prove conditions instcombine can't by /// exploiting range information. static bool processSDiv(BinaryOperator *SDI, LazyValueInfo *LVI) { - if (SDI->getType()->isVectorTy() || !allOperandsAreNonNegative(SDI, LVI)) + if (SDI->getType()->isVectorTy()) return false; + enum class Domain { NonNegative, NonPositive, Unknown }; + auto getDomain = [&](Value *V) { + if (isNonNegative(V, LVI, SDI)) + return Domain::NonNegative; + if (isNonPositive(V, LVI, SDI)) + return Domain::NonPositive; + return Domain::Unknown; + }; + + struct Operand { + Value *V; + Domain Domain; + }; + std::array Ops; + for (const auto &I : zip(Ops, SDI->operands())) { + Operand &Op = std::get<0>(I); + Op.V = std::get<1>(I); + Op.Domain = getDomain(Op.V); + if (Op.Domain == Domain::Unknown) + return false; + } + + // We know domains of both of the operands! ++NumSDivs; - auto *BO = BinaryOperator::CreateUDiv(SDI->getOperand(0), SDI->getOperand(1), - SDI->getName(), SDI); - BO->setDebugLoc(SDI->getDebugLoc()); - BO->setIsExact(SDI->isExact()); - SDI->replaceAllUsesWith(BO); + + // We need operands to be non-negative, so negate each one that isn't. + for (Operand &Op : Ops) { + if (Op.Domain == Domain::NonNegative) + continue; + auto *BO = + BinaryOperator::CreateNeg(Op.V, Op.V->getName() + ".nonneg", SDI); + BO->setDebugLoc(SDI->getDebugLoc()); + Op.V = BO; + } + + auto *UDiv = + BinaryOperator::CreateUDiv(Ops[0].V, Ops[1].V, SDI->getName(), SDI); + UDiv->setDebugLoc(SDI->getDebugLoc()); + UDiv->setIsExact(SDI->isExact()); + + Value *Res = UDiv; + + // If the operands had two different domains, we need to negate the result. + if (Ops[0].Domain != Ops[1].Domain) + Res = BinaryOperator::CreateNeg(Res, Res->getName() + ".neg", SDI); + + SDI->replaceAllUsesWith(Res); SDI->eraseFromParent(); // Try to simplify our new udiv. - processUDivOrURem(BO, LVI); + processUDivOrURem(UDiv, LVI); return true; } diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/sdiv.ll b/llvm/test/Transforms/CorrelatedValuePropagation/sdiv.ll index ec5de0010a14f0..8dfa09d4779255 100644 --- a/llvm/test/Transforms/CorrelatedValuePropagation/sdiv.ll +++ b/llvm/test/Transforms/CorrelatedValuePropagation/sdiv.ll @@ -177,8 +177,10 @@ define i32 @test7_pos_neg(i32 %x, i32 %y) { ; CHECK-NEXT: call void @llvm.assume(i1 [[C0]]) ; CHECK-NEXT: [[C1:%.*]] = icmp sle i32 [[Y:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[C1]]) -; CHECK-NEXT: [[DIV:%.*]] = sdiv i32 [[X]], [[Y]] -; CHECK-NEXT: ret i32 [[DIV]] +; CHECK-NEXT: [[Y_NONNEG:%.*]] = sub i32 0, [[Y]] +; CHECK-NEXT: [[DIV1:%.*]] = udiv i32 [[X]], [[Y_NONNEG]] +; CHECK-NEXT: [[DIV1_NEG:%.*]] = sub i32 0, [[DIV1]] +; CHECK-NEXT: ret i32 [[DIV1_NEG]] ; %c0 = icmp sge i32 %x, 0 call void @llvm.assume(i1 %c0) @@ -194,8 +196,10 @@ define i32 @test8_neg_pos(i32 %x, i32 %y) { ; CHECK-NEXT: call void @llvm.assume(i1 [[C0]]) ; CHECK-NEXT: [[C1:%.*]] = icmp sge i32 [[Y:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[C1]]) -; CHECK-NEXT: [[DIV:%.*]] = sdiv i32 [[X]], [[Y]] -; CHECK-NEXT: ret i32 [[DIV]] +; CHECK-NEXT: [[X_NONNEG:%.*]] = sub i32 0, [[X]] +; CHECK-NEXT: [[DIV1:%.*]] = udiv i32 [[X_NONNEG]], [[Y]] +; CHECK-NEXT: [[DIV1_NEG:%.*]] = sub i32 0, [[DIV1]] +; CHECK-NEXT: ret i32 [[DIV1_NEG]] ; %c0 = icmp sle i32 %x, 0 call void @llvm.assume(i1 %c0) @@ -211,8 +215,10 @@ define i32 @test9_neg_neg(i32 %x, i32 %y) { ; CHECK-NEXT: call void @llvm.assume(i1 [[C0]]) ; CHECK-NEXT: [[C1:%.*]] = icmp sle i32 [[Y:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[C1]]) -; CHECK-NEXT: [[DIV:%.*]] = sdiv i32 [[X]], [[Y]] -; CHECK-NEXT: ret i32 [[DIV]] +; CHECK-NEXT: [[X_NONNEG:%.*]] = sub i32 0, [[X]] +; CHECK-NEXT: [[Y_NONNEG:%.*]] = sub i32 0, [[Y]] +; CHECK-NEXT: [[DIV1:%.*]] = udiv i32 [[X_NONNEG]], [[Y_NONNEG]] +; CHECK-NEXT: ret i32 [[DIV1]] ; %c0 = icmp sle i32 %x, 0 call void @llvm.assume(i1 %c0)