Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions llvm/include/llvm/Analysis/ScalarEvolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,10 @@ class ScalarEvolution {
getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
bool Sequential = false);

/// Try to match the pattern generated by getURemExpr(A, B). If successful,
/// Assign A and B to LHS and RHS, respectively.
LLVM_ABI bool matchURem(const SCEV *Expr, const SCEV *&LHS, const SCEV *&RHS);

/// Transitively follow the chain of pointer-type operands until reaching a
/// SCEV that does not have a single pointer operand. This returns a
/// SCEVUnknown pointer for well-formed pointer-type expressions, but corner
Expand Down Expand Up @@ -2316,10 +2320,6 @@ class ScalarEvolution {
/// an add rec on said loop.
void getUsedLoops(const SCEV *S, SmallPtrSetImpl<const Loop *> &LoopsUsed);

/// Try to match the pattern generated by getURemExpr(A, B). If successful,
/// Assign A and B to LHS and RHS, respectively.
LLVM_ABI bool matchURem(const SCEV *Expr, const SCEV *&LHS, const SCEV *&RHS);

/// Look for a SCEV expression with type `SCEVType` and operands `Ops` in
/// `UniqueSCEVs`. Return if found, else nullptr.
SCEV *findExistingSCEVInCache(SCEVTypes SCEVType, ArrayRef<const SCEV *> Ops);
Expand Down
290 changes: 165 additions & 125 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15557,6 +15557,123 @@ void ScalarEvolution::LoopGuards::collectFromPHI(
}
}

// Checks whether Expr is a non-negative constant, and Divisor is a positive
// constant, and returns their APInt in ExprVal and in DivisorVal.
static bool getNonNegExprAndPosDivisor(const SCEV *Expr, const SCEV *Divisor,
APInt &ExprVal, APInt &DivisorVal) {
auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
if (!ConstExpr || !ConstDivisor)
return false;
ExprVal = ConstExpr->getAPInt();
DivisorVal = ConstDivisor->getAPInt();
return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
}

// Return a new SCEV that modifies \p Expr to the closest number divisible by
// \p Divisor and less than or equal to Expr.
// For now, only handle constant Expr and Divisor.
static const SCEV *getPreviousSCEVDivisibleByDivisor(const SCEV *Expr,
const SCEV *Divisor,
ScalarEvolution &SE) {
APInt ExprVal;
APInt DivisorVal;
if (!getNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
return Expr;
APInt Rem = ExprVal.urem(DivisorVal);
// return the SCEV: Expr - Expr % Divisor
return SE.getConstant(ExprVal - Rem);
}

// Return a new SCEV that modifies \p Expr to the closest number divisible by
// \p Divisor and greater than or equal to Expr.
// For now, only handle constant Expr and Divisor.
static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
const SCEV *Divisor,
ScalarEvolution &SE) {
APInt ExprVal;
APInt DivisorVal;
if (!getNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
return Expr;
APInt Rem = ExprVal.urem(DivisorVal);
if (!Rem.isZero())
// return the SCEV: Expr + Divisor - Expr % Divisor
return SE.getConstant(ExprVal + DivisorVal - Rem);
return Expr;
}

static bool collectDivisibilityInformation(
ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
DenseMap<const SCEV *, const SCEV *> &DivInfo,
DenseMap<const SCEV *, const SCEV *> &Multiples, ScalarEvolution &SE) {
// If we have LHS == 0, check if LHS is computing a property of some unknown
// SCEV %v which we can rewrite %v to express explicitly.
if (Predicate != CmpInst::ICMP_EQ || !match(RHS, m_scev_Zero()))
return false;
// If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
// explicitly express that.
const SCEV *URemLHS = nullptr;
const SCEV *URemRHS = nullptr;
if (!SE.matchURem(LHS, URemLHS, URemRHS))
return false;
if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
const auto *Multiple = SE.getMulExpr(SE.getUDivExpr(LHS, URemRHS), URemRHS);
DivInfo[LHSUnknown] = Multiple;
Multiples[LHSUnknown] = URemRHS;
return true;
}
return false;
}

// Check if the condition is a divisibility guard (A % B == 0).
static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
ScalarEvolution &SE) {
const SCEV *X, *Y;
return SE.matchURem(LHS, X, Y) && RHS->isZero();
}

// Apply divisibility by \p Divisor on MinMaxExpr with constant values,
// recursively. This is done by aligning up/down the constant value to the
// Divisor.
static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
const SCEV *Divisor,
ScalarEvolution &SE) {
// Return true if \p Expr is a MinMax SCEV expression with a non-negative
// constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
// the non-constant operand and in \p LHS the constant operand.
auto IsMinMaxSCEVWithNonNegativeConstant =
[](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
const SCEV *&RHS) {
if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
if (MinMax->getNumOperands() != 2)
return false;
if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
if (C->getAPInt().isNegative())
return false;
SCTy = MinMax->getSCEVType();
LHS = MinMax->getOperand(0);
RHS = MinMax->getOperand(1);
return true;
}
}
return false;
};

const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
SCEVTypes SCTy;
if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
MinMaxRHS))
return MinMaxExpr;
auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
auto *DivisibleExpr =
IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
: getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
SmallVector<const SCEV *> Ops = {
applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
return SE.getMinMaxExpr(SCTy, Ops);
}

void ScalarEvolution::LoopGuards::collectFromBlock(
ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
const BasicBlock *Block, const BasicBlock *Pred,
Expand All @@ -15567,19 +15684,14 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
SmallVector<const SCEV *> ExprsToRewrite;
auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
const SCEV *RHS,
DenseMap<const SCEV *, const SCEV *>
&RewriteMap) {
DenseMap<const SCEV *, const SCEV *> &RewriteMap,
const DenseMap<const SCEV *, const SCEV *>
&DivInfo) {
// WARNING: It is generally unsound to apply any wrap flags to the proposed
// replacement SCEV which isn't directly implied by the structure of that
// SCEV. In particular, using contextual facts to imply flags is *NOT*
// legal. See the scoping rules for flags in the header to understand why.

// If LHS is a constant, apply information to the other expression.
if (isa<SCEVConstant>(LHS)) {
std::swap(LHS, RHS);
Predicate = CmpInst::getSwappedPredicate(Predicate);
}

// Check for a condition of the form (-C1 + X < C2). InstCombine will
// create this form when combining two checks of the form (X u< C2 + C1) and
// (X >=u C1).
Expand Down Expand Up @@ -15612,115 +15724,6 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
if (MatchRangeCheckIdiom())
return;

// Return true if \p Expr is a MinMax SCEV expression with a non-negative
// constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
// the non-constant operand and in \p LHS the constant operand.
auto IsMinMaxSCEVWithNonNegativeConstant =
[&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
const SCEV *&RHS) {
if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
if (MinMax->getNumOperands() != 2)
return false;
if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
if (C->getAPInt().isNegative())
return false;
SCTy = MinMax->getSCEVType();
LHS = MinMax->getOperand(0);
RHS = MinMax->getOperand(1);
return true;
}
}
return false;
};

// Checks whether Expr is a non-negative constant, and Divisor is a positive
// constant, and returns their APInt in ExprVal and in DivisorVal.
auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
APInt &ExprVal, APInt &DivisorVal) {
auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
if (!ConstExpr || !ConstDivisor)
return false;
ExprVal = ConstExpr->getAPInt();
DivisorVal = ConstDivisor->getAPInt();
return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
};

// Return a new SCEV that modifies \p Expr to the closest number divides by
// \p Divisor and greater or equal than Expr.
// For now, only handle constant Expr and Divisor.
auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
const SCEV *Divisor) {
APInt ExprVal;
APInt DivisorVal;
if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
return Expr;
APInt Rem = ExprVal.urem(DivisorVal);
if (!Rem.isZero())
// return the SCEV: Expr + Divisor - Expr % Divisor
return SE.getConstant(ExprVal + DivisorVal - Rem);
return Expr;
};

// Return a new SCEV that modifies \p Expr to the closest number divides by
// \p Divisor and less or equal than Expr.
// For now, only handle constant Expr and Divisor.
auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
const SCEV *Divisor) {
APInt ExprVal;
APInt DivisorVal;
if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
return Expr;
APInt Rem = ExprVal.urem(DivisorVal);
// return the SCEV: Expr - Expr % Divisor
return SE.getConstant(ExprVal - Rem);
};

// Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
// recursively. This is done by aligning up/down the constant value to the
// Divisor.
std::function<const SCEV *(const SCEV *, const SCEV *)>
ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
const SCEV *Divisor) {
const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
SCEVTypes SCTy;
if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
MinMaxRHS))
return MinMaxExpr;
auto IsMin =
isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
assert(SE.isKnownNonNegative(MinMaxLHS) &&
"Expected non-negative operand!");
auto *DivisibleExpr =
IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
: GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
SmallVector<const SCEV *> Ops = {
ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
return SE.getMinMaxExpr(SCTy, Ops);
};

// If we have LHS == 0, check if LHS is computing a property of some unknown
// SCEV %v which we can rewrite %v to express explicitly.
if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
// If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
// explicitly express that.
const SCEV *URemLHS = nullptr;
const SCEV *URemRHS = nullptr;
if (SE.matchURem(LHS, URemLHS, URemRHS)) {
if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
auto I = RewriteMap.find(LHSUnknown);
const SCEV *RewrittenLHS =
I != RewriteMap.end() ? I->second : LHSUnknown;
RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
const auto *Multiple =
SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
RewriteMap[LHSUnknown] = Multiple;
ExprsToRewrite.push_back(LHSUnknown);
return;
}
}
}

// Do not apply information for constants or if RHS contains an AddRec.
if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
return;
Expand Down Expand Up @@ -15751,7 +15754,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock(

const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
const SCEV *DividesBy = nullptr;
const APInt &Multiple = SE.getConstantMultiple(RewrittenLHS);
// Apply divisibility information when computing the constant multiple.
LoopGuards DivGuards(SE);
DivGuards.RewriteMap = DivInfo;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this copying the map?

const APInt &Multiple =
SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
if (!Multiple.isOne())
DividesBy = SE.getConstant(Multiple);

Expand All @@ -15775,21 +15782,23 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
[[fallthrough]];
case CmpInst::ICMP_SLT: {
RHS = SE.getMinusSCEV(RHS, One);
RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
RHS = DividesBy ? getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE)
: RHS;
break;
}
case CmpInst::ICMP_UGT:
case CmpInst::ICMP_SGT:
RHS = SE.getAddExpr(RHS, One);
RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
RHS = DividesBy ? getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE) : RHS;
break;
case CmpInst::ICMP_ULE:
case CmpInst::ICMP_SLE:
RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
RHS = DividesBy ? getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE)
: RHS;
break;
case CmpInst::ICMP_UGE:
case CmpInst::ICMP_SGE:
RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
RHS = DividesBy ? getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE) : RHS;
break;
default:
break;
Expand Down Expand Up @@ -15843,7 +15852,8 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
case CmpInst::ICMP_NE:
if (match(RHS, m_scev_Zero())) {
const SCEV *OneAlignedUp =
DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
DividesBy ? getNextSCEVDivisibleByDivisor(One, DividesBy, SE)
: One;
To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
}
break;
Expand Down Expand Up @@ -15916,8 +15926,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock(

// Now apply the information from the collected conditions to
// Guards.RewriteMap. Conditions are processed in reverse order, so the
// earliest conditions is processed first. This ensures the SCEVs with the
// earliest conditions is processed first, except guards with divisibility
// information, which are moved to the back. This ensures the SCEVs with the
// shortest dependency chains are constructed first.
SmallVector<std::tuple<CmpInst::Predicate, const SCEV *, const SCEV *>>
GuardsToProcess;
for (auto [Term, EnterIfTrue] : reverse(Terms)) {
SmallVector<Value *, 8> Worklist;
SmallPtrSet<Value *, 8> Visited;
Expand All @@ -15932,7 +15945,12 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
// If LHS is a constant, apply information to the other expression.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And if LHS is not a constant, we make an arbitrary choice? Should we be using CompareSCEVComplexity or something?

if (isa<SCEVConstant>(LHS)) {
std::swap(LHS, RHS);
Predicate = CmpInst::getSwappedPredicate(Predicate);
}
GuardsToProcess.emplace_back(Predicate, LHS, RHS);
continue;
}

Expand All @@ -15945,6 +15963,28 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
}
}

// Process divisibility guards in reverse order to populate DivInfo early.
DenseMap<const SCEV *, const SCEV *> Multiples;
DenseMap<const SCEV *, const SCEV *> DivInfo;
for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
if (!isDivisibilityGuard(LHS, RHS, SE))
continue;
collectDivisibilityInformation(Predicate, LHS, RHS, DivInfo, Multiples, SE);
}

for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivInfo);

// Apply divisibility information last. This ensures it is applied to the
// outermost expression after other rewrites for the given value.
for (const auto &[K, V] : Multiples) {
Guards.RewriteMap[K] = SE.getMulExpr(
SE.getUDivExpr(applyDivisibilityOnMinMaxExpr(Guards.rewrite(K), V, SE),
V),
V);
ExprsToRewrite.push_back(K);
}

// Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
// the replacement expressions are contained in the ranges of the replaced
// expressions.
Expand Down
Loading
Loading