Skip to content

Commit

Permalink
[SCEV] Sequential/in-order UMin expression
Browse files Browse the repository at this point in the history
As discussed in #53020 / https://reviews.llvm.org/D116692,
SCEV is forbidden from reasoning about 'backedge taken count'
if the branch condition is a poison-safe logical operation,
which is conservatively correct, but is severely limiting.

Instead, we should have a way to express those
poison blocking properties in SCEV expressions.

The proposed semantics is:
```
Sequential/in-order min/max SCEV expressions are non-commutative variants
of commutative min/max SCEV expressions. If none of their operands
are poison, then they are functionally equivalent; otherwise,
if the operand that represents the saturation point* of the given expression
comes before the first poison operand, then the whole expression is not poison,
but is instead that saturation point.
```
* saturation point - the maximal/minimal possible integer value for the given type

The lowering is straightforward:
```
compare each operand to the saturation point,
perform a sequential, in-order (poison-safe!) logical-or reduction
over those checks, and if the reduction returned true, then return the
saturation point; else return the naive min/max reduction over the operands
```
https://alive2.llvm.org/ce/z/Q7jxvH (2 ops)
https://alive2.llvm.org/ce/z/QCRrhk (3 ops)
Note that we don't need to check the last operand: https://alive2.llvm.org/ce/z/abvHQS
Note that this is not commutative: https://alive2.llvm.org/ce/z/FK9e97

That allows us to handle the patterns in question.

Reviewed By: nikic, reames

Differential Revision: https://reviews.llvm.org/D116766
  • Loading branch information
LebedevRI committed Jan 10, 2022
1 parent 07a0b0e commit 82fb4f4
Show file tree
Hide file tree
Showing 13 changed files with 339 additions and 86 deletions.
14 changes: 10 additions & 4 deletions llvm/include/llvm/Analysis/ScalarEvolution.h
Expand Up @@ -629,14 +629,18 @@ class ScalarEvolution {
const SCEV *getAbsExpr(const SCEV *Op, bool IsNSW);
const SCEV *getMinMaxExpr(SCEVTypes Kind,
SmallVectorImpl<const SCEV *> &Operands);
const SCEV *getSequentialMinMaxExpr(SCEVTypes Kind,
SmallVectorImpl<const SCEV *> &Operands);
const SCEV *getSMaxExpr(const SCEV *LHS, const SCEV *RHS);
const SCEV *getSMaxExpr(SmallVectorImpl<const SCEV *> &Operands);
const SCEV *getUMaxExpr(const SCEV *LHS, const SCEV *RHS);
const SCEV *getUMaxExpr(SmallVectorImpl<const SCEV *> &Operands);
const SCEV *getSMinExpr(const SCEV *LHS, const SCEV *RHS);
const SCEV *getSMinExpr(SmallVectorImpl<const SCEV *> &Operands);
const SCEV *getUMinExpr(const SCEV *LHS, const SCEV *RHS);
const SCEV *getUMinExpr(SmallVectorImpl<const SCEV *> &Operands);
const SCEV *getUMinExpr(const SCEV *LHS, const SCEV *RHS,
bool Sequential = false);
const SCEV *getUMinExpr(SmallVectorImpl<const SCEV *> &Operands,
bool Sequential = false);
const SCEV *getUnknown(Value *V);
const SCEV *getCouldNotCompute();

Expand Down Expand Up @@ -728,11 +732,13 @@ class ScalarEvolution {

/// Promote the operands to the wider of the types using zero-extension, and
/// then perform a umin operation with them.
const SCEV *getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS);
const SCEV *getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS,
bool Sequential = false);

/// Promote the operands to the wider of the types using zero-extension, and
/// then perform a umin operation with them. N-ary function.
const SCEV *getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops);
const SCEV *getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
bool Sequential = false);

/// Transitively follow the chain of pointer-type operands until reaching a
/// SCEV that does not have a single pointer operand. This returns a
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/Analysis/ScalarEvolutionDivision.h
Expand Up @@ -42,6 +42,7 @@ struct SCEVDivision : public SCEVVisitor<SCEVDivision, void> {
void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {}
void visitSMinExpr(const SCEVSMinExpr *Numerator) {}
void visitUMinExpr(const SCEVUMinExpr *Numerator) {}
void visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Numerator) {}
void visitUnknown(const SCEVUnknown *Numerator) {}
void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {}

Expand Down
64 changes: 64 additions & 0 deletions llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
Expand Up @@ -51,6 +51,7 @@ enum SCEVTypes : unsigned short {
scUMinExpr,
scSMinExpr,
scPtrToInt,
scSequentialUMinExpr,
scUnknown,
scCouldNotCompute
};
Expand Down Expand Up @@ -228,6 +229,7 @@ class SCEVNAryExpr : public SCEV {
return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr ||
S->getSCEVType() == scSequentialUMinExpr ||
S->getSCEVType() == scAddRecExpr;
}
};
Expand Down Expand Up @@ -501,6 +503,54 @@ class SCEVUMinExpr : public SCEVMinMaxExpr {
static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; }
};

/// This node is the base class for sequential/in-order min/max selections.
/// Note that their fundamental difference from SCEVMinMaxExpr's is that they
/// are early-returning upon reaching saturation point.
/// I.e. given `0 umin_seq poison`, the result will be `0`,
/// while the result of `0 umin poison` is `poison`.
class SCEVSequentialMinMaxExpr : public SCEVNAryExpr {
friend class ScalarEvolution;

static bool isSequentialMinMaxType(enum SCEVTypes T) {
return T == scSequentialUMinExpr;
}

/// Set flags for a non-recurrence without clearing previously set flags.
void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }

protected:
/// Note: Constructing subclasses via this constructor is allowed
SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
const SCEV *const *O, size_t N)
: SCEVNAryExpr(ID, T, O, N) {
assert(isSequentialMinMaxType(T));
// Min and max never overflow
setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
}

public:
Type *getType() const { return getOperand(0)->getType(); }

static bool classof(const SCEV *S) {
return isSequentialMinMaxType(S->getSCEVType());
}
};

/// Sequential/in-order unsigned minimum selection: the concrete
/// SCEVSequentialMinMaxExpr node tagged scSequentialUMinExpr.
class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr {
  friend class ScalarEvolution;

  /// Private constructor (ScalarEvolution is a friend of this class).
  /// Forwards \p ID, \p O and \p N to the base class and fixes the
  /// expression kind to scSequentialUMinExpr.
  SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O,
                         size_t N)
      : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {}

public:
  /// Methods for support type inquiry through isa, cast, and dyn_cast:
  static bool classof(const SCEV *S) {
    const auto Kind = S->getSCEVType();
    return Kind == scSequentialUMinExpr;
  }
};

/// This means that we are dealing with an entirely unknown SCEV
/// value, and only represent it as its LLVM Value. This is the
/// "bottom" value for the analysis.
Expand Down Expand Up @@ -576,6 +626,9 @@ template <typename SC, typename RetVal = void> struct SCEVVisitor {
return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S);
case scUMinExpr:
return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S);
case scSequentialUMinExpr:
return ((SC *)this)
->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S);
case scUnknown:
return ((SC *)this)->visitUnknown((const SCEVUnknown *)S);
case scCouldNotCompute:
Expand Down Expand Up @@ -630,6 +683,7 @@ template <typename SV> class SCEVTraversal {
case scUMaxExpr:
case scSMinExpr:
case scUMinExpr:
case scSequentialUMinExpr:
case scAddRecExpr:
for (const auto *Op : cast<SCEVNAryExpr>(S)->operands())
push(Op);
Expand Down Expand Up @@ -815,6 +869,16 @@ class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
return !Changed ? Expr : SE.getUMinExpr(Operands);
}

/// Rewrites a sequential umin by visiting each operand in order.
/// If every operand is returned unchanged, \p Expr itself is returned;
/// otherwise a new expression is requested from SE.getUMinExpr() with
/// Sequential=true, passing the rewritten operands in their original order.
const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
  SmallVector<const SCEV *, 2> NewOps;
  bool AnyRewritten = false;
  for (const SCEV *OldOp : Expr->operands()) {
    const SCEV *NewOp = ((SC *)this)->visit(OldOp);
    NewOps.push_back(NewOp);
    if (NewOp != OldOp)
      AnyRewritten = true;
  }

  // Nothing changed: hand back the original node instead of rebuilding it.
  if (!AnyRewritten)
    return Expr;
  return SE.getUMinExpr(NewOps, /*Sequential=*/true);
}

const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; }

const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
Expand Down
9 changes: 9 additions & 0 deletions llvm/include/llvm/IR/IRBuilder.h
Expand Up @@ -1571,6 +1571,15 @@ class IRBuilderBase {
Cond2, Name);
}

/// Folds \p Ops into a chain of two-operand logical ors, left to right:
/// ((Ops[0] || Ops[1]) || Ops[2]) || ...
/// NOTE: this is a sequential, non-commutative, ordered reduction!
/// \p Ops must not be empty (asserted). With a single element, that element
/// is returned as-is and no two-operand CreateLogicalOr call is made.
Value *CreateLogicalOr(ArrayRef<Value *> Ops) {
  assert(!Ops.empty());
  Value *Result = Ops.front();
  // Every operand after the first is or'ed onto the running result in order.
  for (Value *Next : Ops.drop_front())
    Result = CreateLogicalOr(Result, Next);
  return Result;
}

CallInst *CreateConstrainedFPBinOp(
Intrinsic::ID ID, Value *L, Value *R, Instruction *FMFSource = nullptr,
const Twine &Name = "", MDNode *FPMathTag = nullptr,
Expand Down
10 changes: 10 additions & 0 deletions llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
Expand Up @@ -450,6 +450,14 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
/// Determine the most "relevant" loop for the given SCEV.
const Loop *getRelevantLoop(const SCEV *);

Value *expandSMaxExpr(const SCEVNAryExpr *S);

Value *expandUMaxExpr(const SCEVNAryExpr *S);

Value *expandSMinExpr(const SCEVNAryExpr *S);

Value *expandUMinExpr(const SCEVNAryExpr *S);

Value *visitConstant(const SCEVConstant *S) { return S->getValue(); }

Value *visitPtrToIntExpr(const SCEVPtrToIntExpr *S);
Expand All @@ -476,6 +484,8 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {

Value *visitUMinExpr(const SCEVUMinExpr *S);

Value *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S);

Value *visitUnknown(const SCEVUnknown *S) { return S->getValue(); }

void rememberInstruction(Value *I);
Expand Down

0 comments on commit 82fb4f4

Please sign in to comment.