Skip to content

Commit

Permalink
[SCEV] Apply NSW and NUW flags via poison value analysis for sub, mul…
Browse files Browse the repository at this point in the history
… and shl

Summary:
http://reviews.llvm.org/D11212 made Scalar Evolution able to propagate NSW and NUW flags from instructions to SCEVs for add instructions. This patch expands that to sub, mul and shl instructions.

This change makes LSR able to generate pointer induction variables for loops like these, where the index is 32-bit and the pointer is 64-bit:

  for (int i = 0; i < numIterations; ++i)
    sum += ptr[i - offset];

  for (int i = 0; i < numIterations; ++i)
    sum += ptr[i * stride];

  for (int i = 0; i < numIterations; ++i)
    sum += ptr[3 * (i << 7)];


Reviewers: atrick, sanjoy

Subscribers: sanjoy, majnemer, hfinkel, llvm-commits, meheff, jingyue, eliben

Differential Revision: http://reviews.llvm.org/D11860

llvm-svn: 245118
  • Loading branch information
broune committed Aug 14, 2015
1 parent b399095 commit 9791ed4
Show file tree
Hide file tree
Showing 6 changed files with 421 additions and 37 deletions.
3 changes: 2 additions & 1 deletion llvm/include/llvm/Analysis/ScalarEvolution.h
Expand Up @@ -712,7 +712,8 @@ namespace llvm {

/// getNegativeSCEV - Return the SCEV object corresponding to -V.
///
const SCEV *getNegativeSCEV(const SCEV *V);
const SCEV *getNegativeSCEV(const SCEV *V,
SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap);

/// getNotSCEV - Return the SCEV object corresponding to ~V.
///
Expand Down
113 changes: 79 additions & 34 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Expand Up @@ -3339,15 +3339,16 @@ const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {

/// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
///
const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V) {
const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
SCEV::NoWrapFlags Flags) {
if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
return getConstant(
cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));

Type *Ty = V->getType();
Ty = getEffectiveSCEVType(Ty);
return getMulExpr(V,
getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))));
return getMulExpr(
V, getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))), Flags);
}

/// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
Expand All @@ -3366,15 +3367,40 @@ const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
/// getMinusSCEV - Return LHS-RHS. Minus is represented in SCEV as A+B*-1.
const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
SCEV::NoWrapFlags Flags) {
assert(!maskFlags(Flags, SCEV::FlagNUW) && "subtraction does not have NUW");

// Fast path: X - X --> 0.
if (LHS == RHS)
return getConstant(LHS->getType(), 0);

// X - Y --> X + -Y.
// X -(nsw || nuw) Y --> X + -Y.
return getAddExpr(LHS, getNegativeSCEV(RHS));
// We represent LHS - RHS as LHS + (-1)*RHS. Because of this
// representation, we cannot make much use of NUW.
auto AddFlags = SCEV::FlagAnyWrap;
const bool RHSIsNotMinSigned =
!getSignedRange(RHS).getSignedMin().isMinSignedValue();
if (maskFlags(Flags, SCEV::FlagNSW) == SCEV::FlagNSW) {
// Let M be the minimum representable signed value. Then (-1)*RHS
// signed-wraps if and only if RHS is M. That can happen even for
// a NSW subtraction because e.g. (-1)*M signed-wraps even though
// -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
// (-1)*RHS, we need to prove that RHS != M.
//
// If LHS is non-negative and we know that LHS - RHS does not
// signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
// either by proving that RHS > M or that LHS >= 0.
if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
AddFlags = SCEV::FlagNSW;
}
}

// FIXME: Find a correct way to transfer NSW to (-1)*RHS when LHS -
// RHS is NSW and LHS >= 0.
//
// The difficulty here is that the NSW flag may have been proven
// relative to a loop that is to be found in a recurrence in LHS and
// not in RHS. Applying NSW to (-1)*RHS may then let the NSW have a
// larger scope than intended.
auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;

return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags);
}

/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
Expand Down Expand Up @@ -4094,6 +4120,7 @@ ScalarEvolution::getRange(const SCEV *S,
}

SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
const BinaryOperator *BinOp = cast<BinaryOperator>(V);

// Return early if there are no flags to propagate to the SCEV.
Expand Down Expand Up @@ -4185,9 +4212,6 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
// because it leads to N-1 getAddExpr calls for N ultimate operands.
// Instead, gather up all the operands and make a single getAddExpr call.
// LLVM IR canonical form means we need only traverse the left operands.
//
// FIXME: Expand this handling of NSW and NUW to other instructions, like
// sub and mul.
SmallVector<const SCEV *, 4> AddOps;
for (Value *Op = U;; Op = U->getOperand(0)) {
U = dyn_cast<Operator>(Op);
Expand All @@ -4198,7 +4222,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
break;
}

if (auto *OpSCEV = getExistingSCEV(Op)) {
if (auto *OpSCEV = getExistingSCEV(U)) {
AddOps.push_back(OpSCEV);
break;
}
Expand All @@ -4210,45 +4234,57 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
// since the flags are only known to apply to this particular
// addition - they may not apply to other additions that can be
// formed with operands from AddOps.
//
// FIXME: Expand this to sub instructions.
if (Opcode == Instruction::Add && isa<BinaryOperator>(U)) {
SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(U);
if (Flags != SCEV::FlagAnyWrap) {
AddOps.push_back(getAddExpr(getSCEV(U->getOperand(0)),
getSCEV(U->getOperand(1)), Flags));
break;
}
const SCEV *RHS = getSCEV(U->getOperand(1));
SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(U);
if (Flags != SCEV::FlagAnyWrap) {
const SCEV *LHS = getSCEV(U->getOperand(0));
if (Opcode == Instruction::Sub)
AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
else
AddOps.push_back(getAddExpr(LHS, RHS, Flags));
break;
}

const SCEV *Op1 = getSCEV(U->getOperand(1));
if (Opcode == Instruction::Sub)
AddOps.push_back(getNegativeSCEV(Op1));
AddOps.push_back(getNegativeSCEV(RHS));
else
AddOps.push_back(Op1);
AddOps.push_back(RHS);
}
return getAddExpr(AddOps);
}

case Instruction::Mul: {
// FIXME: Transfer NSW/NUW as in AddExpr.
SmallVector<const SCEV *, 4> MulOps;
MulOps.push_back(getSCEV(U->getOperand(1)));
for (Value *Op = U->getOperand(0);
Op->getValueID() == Instruction::Mul + Value::InstructionVal;
Op = U->getOperand(0)) {
U = cast<Operator>(Op);
for (Value *Op = U;; Op = U->getOperand(0)) {
U = dyn_cast<Operator>(Op);
if (!U || U->getOpcode() != Instruction::Mul) {
assert(Op != V && "V should be a mul");
MulOps.push_back(getSCEV(Op));
break;
}

if (auto *OpSCEV = getExistingSCEV(U)) {
MulOps.push_back(OpSCEV);
break;
}

SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(U);
if (Flags != SCEV::FlagAnyWrap) {
MulOps.push_back(getMulExpr(getSCEV(U->getOperand(0)),
getSCEV(U->getOperand(1)), Flags));
break;
}

MulOps.push_back(getSCEV(U->getOperand(1)));
}
MulOps.push_back(getSCEV(U->getOperand(0)));
return getMulExpr(MulOps);
}
case Instruction::UDiv:
return getUDivExpr(getSCEV(U->getOperand(0)),
getSCEV(U->getOperand(1)));
case Instruction::Sub:
return getMinusSCEV(getSCEV(U->getOperand(0)),
getSCEV(U->getOperand(1)));
return getMinusSCEV(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)),
getNoWrapFlagsFromUB(U));
case Instruction::And:
// For an expression like x&255 that merely masks off the high bits,
// use zext(trunc(x)) as the SCEV expression.
Expand Down Expand Up @@ -4368,9 +4404,18 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
if (SA->getValue().uge(BitWidth))
break;

// It is currently not resolved how to interpret NSW for left
// shift by BitWidth - 1, so we avoid applying flags in that
// case. Remove this check (or this comment) once the situation
// is resolved. See
// http://lists.llvm.org/pipermail/llvm-dev/2015-April/084195.html
// and http://reviews.llvm.org/D8890 .
auto Flags = SCEV::FlagAnyWrap;
if (SA->getValue().ult(BitWidth - 1)) Flags = getNoWrapFlagsFromUB(U);

Constant *X = ConstantInt::get(getContext(),
APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X), Flags);
}
break;

Expand Down
2 changes: 1 addition & 1 deletion llvm/test/Analysis/Delinearization/a.ll
Expand Up @@ -10,7 +10,7 @@
; AddRec: {{{(28 + (4 * (-4 + (3 * %m)) * %o) + %A),+,(8 * %m * %o)}<%for.i>,+,(12 * %o)}<%for.j>,+,20}<%for.k>
; CHECK: Base offset: %A
; CHECK: ArrayDecl[UnknownSize][%m][%o] with elements of 4 bytes.
; CHECK: ArrayRef[{3,+,2}<%for.i>][{-4,+,3}<%for.j>][{7,+,5}<%for.k>]
; CHECK: ArrayRef[{3,+,2}<%for.i>][{-4,+,3}<%for.j>][{7,+,5}<nw><%for.k>]

define void @foo(i64 %n, i64 %m, i64 %o, i32* nocapture %A) #0 {
entry:
Expand Down

0 comments on commit 9791ed4

Please sign in to comment.