Skip to content

Commit

Permalink
[SCEV] Factor out common visiting patterns for SCEV rewriters. NFC.
Browse files Browse the repository at this point in the history
Summary:
Add a SCEVRewriteVisitor class which contains the common
visiting patterns used when rewriting SCEVs.

SCEVParameterRewriter and SCEVApplyRewriter now inherit
from SCEVRewriteVisitor (and are therefore much simpler).

Reviewers: anemet, mzolotukhin, sanjoy

Subscribers: rengolin, llvm-commits, sanjoy

Differential Revision: http://reviews.llvm.org/D13242

llvm-svn: 251283
  • Loading branch information
sbaranga-arm committed Oct 26, 2015
1 parent 29932b0 commit 84d7a40
Showing 1 changed file with 44 additions and 93 deletions.
137 changes: 44 additions & 93 deletions llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
Expand Up @@ -553,82 +553,99 @@ namespace llvm {
T.visitAll(Root);
}

typedef DenseMap<const Value*, Value*> ValueToValueMap;

/// The SCEVParameterRewriter takes a scalar evolution expression and updates
/// the SCEVUnknown components following the Map (Value -> Value).
struct SCEVParameterRewriter
: public SCEVVisitor<SCEVParameterRewriter, const SCEV*> {
/// Recursively visits a SCEV expression and re-writes it.
template<typename SC>
class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
protected:
ScalarEvolution &SE;
public:
static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
ValueToValueMap &Map,
bool InterpretConsts = false) {
SCEVParameterRewriter Rewriter(SE, Map, InterpretConsts);
return Rewriter.visit(Scev);
}

SCEVParameterRewriter(ScalarEvolution &S, ValueToValueMap &M, bool C)
: SE(S), Map(M), InterpretConsts(C) {}
SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {}

const SCEV *visitConstant(const SCEVConstant *Constant) {
return Constant;
}

const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
const SCEV *Operand = visit(Expr->getOperand());
const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
return SE.getTruncateExpr(Operand, Expr->getType());
}

const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
const SCEV *Operand = visit(Expr->getOperand());
const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
return SE.getZeroExtendExpr(Operand, Expr->getType());
}

const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
const SCEV *Operand = visit(Expr->getOperand());
const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
return SE.getSignExtendExpr(Operand, Expr->getType());
}

const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
SmallVector<const SCEV *, 2> Operands;
for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
Operands.push_back(visit(Expr->getOperand(i)));
Operands.push_back(((SC*)this)->visit(Expr->getOperand(i)));
return SE.getAddExpr(Operands);
}

const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
SmallVector<const SCEV *, 2> Operands;
for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
Operands.push_back(visit(Expr->getOperand(i)));
Operands.push_back(((SC*)this)->visit(Expr->getOperand(i)));
return SE.getMulExpr(Operands);
}

const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
return SE.getUDivExpr(visit(Expr->getLHS()), visit(Expr->getRHS()));
return SE.getUDivExpr(((SC*)this)->visit(Expr->getLHS()),
((SC*)this)->visit(Expr->getRHS()));
}

const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
SmallVector<const SCEV *, 2> Operands;
for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
Operands.push_back(visit(Expr->getOperand(i)));
Operands.push_back(((SC*)this)->visit(Expr->getOperand(i)));
return SE.getAddRecExpr(Operands, Expr->getLoop(),
Expr->getNoWrapFlags());
}

const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
SmallVector<const SCEV *, 2> Operands;
for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
Operands.push_back(visit(Expr->getOperand(i)));
Operands.push_back(((SC*)this)->visit(Expr->getOperand(i)));
return SE.getSMaxExpr(Operands);
}

const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
SmallVector<const SCEV *, 2> Operands;
for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
Operands.push_back(visit(Expr->getOperand(i)));
Operands.push_back(((SC*)this)->visit(Expr->getOperand(i)));
return SE.getUMaxExpr(Operands);
}

const SCEV *visitUnknown(const SCEVUnknown *Expr) {
return Expr;
}

const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
return Expr;
}
};

typedef DenseMap<const Value*, Value*> ValueToValueMap;

/// The SCEVParameterRewriter takes a scalar evolution expression and updates
/// the SCEVUnknown components following the Map (Value -> Value).
class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
public:
static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
ValueToValueMap &Map,
bool InterpretConsts = false) {
SCEVParameterRewriter Rewriter(SE, Map, InterpretConsts);
return Rewriter.visit(Scev);
}

SCEVParameterRewriter(ScalarEvolution &SE, ValueToValueMap &M, bool C)
: SCEVRewriteVisitor(SE), Map(M), InterpretConsts(C) {}

const SCEV *visitUnknown(const SCEVUnknown *Expr) {
Value *V = Expr->getValue();
if (Map.count(V)) {
Expand All @@ -640,12 +657,7 @@ namespace llvm {
return Expr;
}

const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
return Expr;
}

private:
ScalarEvolution &SE;
ValueToValueMap &Map;
bool InterpretConsts;
};
Expand All @@ -654,54 +666,16 @@ namespace llvm {

/// The SCEVApplyRewriter takes a scalar evolution expression and applies
/// the Map (Loop -> SCEV) to all AddRecExprs.
struct SCEVApplyRewriter
: public SCEVVisitor<SCEVApplyRewriter, const SCEV*> {
class SCEVApplyRewriter : public SCEVRewriteVisitor<SCEVApplyRewriter> {
public:
static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
ScalarEvolution &SE) {
SCEVApplyRewriter Rewriter(SE, Map);
return Rewriter.visit(Scev);
}

SCEVApplyRewriter(ScalarEvolution &S, LoopToScevMapT &M)
: SE(S), Map(M) {}

const SCEV *visitConstant(const SCEVConstant *Constant) {
return Constant;
}

const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
const SCEV *Operand = visit(Expr->getOperand());
return SE.getTruncateExpr(Operand, Expr->getType());
}

const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
const SCEV *Operand = visit(Expr->getOperand());
return SE.getZeroExtendExpr(Operand, Expr->getType());
}

const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
const SCEV *Operand = visit(Expr->getOperand());
return SE.getSignExtendExpr(Operand, Expr->getType());
}

const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
SmallVector<const SCEV *, 2> Operands;
for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
Operands.push_back(visit(Expr->getOperand(i)));
return SE.getAddExpr(Operands);
}

const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
SmallVector<const SCEV *, 2> Operands;
for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
Operands.push_back(visit(Expr->getOperand(i)));
return SE.getMulExpr(Operands);
}

const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
return SE.getUDivExpr(visit(Expr->getLHS()), visit(Expr->getRHS()));
}
SCEVApplyRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
: SCEVRewriteVisitor(SE), Map(M) {}

const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
SmallVector<const SCEV *, 2> Operands;
Expand All @@ -718,30 +692,7 @@ namespace llvm {
return Rec->evaluateAtIteration(Map[L], SE);
}

const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
SmallVector<const SCEV *, 2> Operands;
for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
Operands.push_back(visit(Expr->getOperand(i)));
return SE.getSMaxExpr(Operands);
}

const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
SmallVector<const SCEV *, 2> Operands;
for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
Operands.push_back(visit(Expr->getOperand(i)));
return SE.getUMaxExpr(Operands);
}

const SCEV *visitUnknown(const SCEVUnknown *Expr) {
return Expr;
}

const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
return Expr;
}

private:
ScalarEvolution &SE;
LoopToScevMapT &Map;
};

Expand Down

0 comments on commit 84d7a40

Please sign in to comment.