Skip to content

Commit

Permalink
[SCEV] Generalize SCEVParameterRewriter to accept SCEV expression as …
Browse files Browse the repository at this point in the history
…target.

This patch extends SCEVParameterRewriter to support rewriting unknown
epxressions to arbitrary SCEV expressions. It will be used by further
patches.

Reviewed By: reames

Differential Revision: https://reviews.llvm.org/D67176
  • Loading branch information
fhahn committed Sep 18, 2020
1 parent c102005 commit 4635f60
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 23 deletions.
27 changes: 11 additions & 16 deletions llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
Original file line number Diff line number Diff line change
Expand Up @@ -810,35 +810,30 @@ class Type;
};

using ValueToValueMap = DenseMap<const Value *, Value *>;
using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;

/// The SCEVParameterRewriter takes a scalar evolution expression and updates
/// the SCEVUnknown components following the Map (Value -> Value).
/// the SCEVUnknown components following the Map (Value -> SCEV).
class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
public:
static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
ValueToValueMap &Map,
bool InterpretConsts = false) {
SCEVParameterRewriter Rewriter(SE, Map, InterpretConsts);
ValueToSCEVMapTy &Map) {
SCEVParameterRewriter Rewriter(SE, Map);
return Rewriter.visit(Scev);
}

SCEVParameterRewriter(ScalarEvolution &SE, ValueToValueMap &M, bool C)
: SCEVRewriteVisitor(SE), Map(M), InterpretConsts(C) {}
SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
: SCEVRewriteVisitor(SE), Map(M) {}

const SCEV *visitUnknown(const SCEVUnknown *Expr) {
Value *V = Expr->getValue();
if (Map.count(V)) {
Value *NV = Map[V];
if (InterpretConsts && isa<ConstantInt>(NV))
return SE.getConstant(cast<ConstantInt>(NV));
return SE.getUnknown(NV);
}
return Expr;
auto I = Map.find(Expr->getValue());
if (I == Map.end())
return Expr;
return I->second;
}

private:
ValueToValueMap &Map;
bool InterpretConsts;
ValueToSCEVMapTy &Map;
};

using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
Expand Down
12 changes: 5 additions & 7 deletions llvm/lib/Analysis/ScalarEvolutionDivision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,16 +215,14 @@ void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) {
return cannotDivide(Numerator);

// The Remainder is obtained by replacing Denominator by 0 in Numerator.
ValueToValueMap RewriteMap;
RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
cast<SCEVConstant>(Zero)->getValue();
Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
ValueToSCEVMapTy RewriteMap;
RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = Zero;
Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);

if (Remainder->isZero()) {
// The Quotient is obtained by replacing Denominator by 1 in Numerator.
RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
cast<SCEVConstant>(One)->getValue();
Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = One;
Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
return;
}

Expand Down
45 changes: 45 additions & 0 deletions llvm/unittests/Analysis/ScalarEvolutionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,12 @@ static Instruction *getInstructionByName(Function &F, StringRef Name) {
llvm_unreachable("Expected to find instruction!");
}

static Value *getArgByName(Function &F, StringRef Name) {
for (auto &Arg : F.args())
if (Arg.getName() == Name)
return &Arg;
llvm_unreachable("Expected to find instruction!");
}
TEST_F(ScalarEvolutionsTest, CommutativeExprOperandOrder) {
LLVMContext C;
SMDiagnostic Err;
Expand Down Expand Up @@ -1120,4 +1126,43 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
});
}

TEST_F(ScalarEvolutionsTest, SCEVrewriteUnknowns) {
LLVMContext C;
SMDiagnostic Err;
std::unique_ptr<Module> M = parseAssemblyString(
"define void @foo(i32 %i) { "
"entry: "
" %cmp3 = icmp ult i32 %i, 16 "
" br i1 %cmp3, label %loop.body, label %exit "
"loop.body: "
" %iv = phi i32 [ %iv.next, %loop.body ], [ %i, %entry ] "
" %iv.next = add nsw i32 %iv, 1 "
" %cmp = icmp eq i32 %iv.next, 16 "
" br i1 %cmp, label %exit, label %loop.body "
"exit: "
" ret void "
"} ",
Err, C);

ASSERT_TRUE(M && "Could not parse module?");
ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");

runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv")); // {0,+,1}
auto *ScevI = SE.getSCEV(getArgByName(F, "i")); // {0,+,1}

ValueToSCEVMapTy RewriteMap;
RewriteMap[cast<SCEVUnknown>(ScevI)->getValue()] =
SE.getUMinExpr(ScevI, SE.getConstant(ScevI->getType(), 17));
auto *WithUMin = SCEVParameterRewriter::rewrite(ScevIV, SE, RewriteMap);

EXPECT_NE(WithUMin, ScevIV);
auto *AR = dyn_cast<SCEVAddRecExpr>(WithUMin);
EXPECT_TRUE(AR);
EXPECT_EQ(AR->getStart(),
SE.getUMinExpr(ScevI, SE.getConstant(ScevI->getType(), 17)));
EXPECT_EQ(AR->getStepRecurrence(SE),
cast<SCEVAddRecExpr>(ScevIV)->getStepRecurrence(SE));
});
}
} // end namespace llvm

0 comments on commit 4635f60

Please sign in to comment.