432 changes: 260 additions & 172 deletions llvm/lib/Transforms/Scalar/LoopFlatten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@
//
// The intention is to optimise loop nests like this, which together access an
// array linearly:
//
// for (int i = 0; i < N; ++i)
// for (int j = 0; j < M; ++j)
// f(A[i*M+j]);
//
// into one loop:
//
// for (int i = 0; i < (N*M); ++i)
// f(A[i]);
//
Expand All @@ -22,7 +25,27 @@
// expression like i*M+j. If they had any other uses, we would have to insert a
// div/mod to reconstruct the original values, so this wouldn't be profitable.
//
// We also need to prove that N*M will not overflow.
// We also need to prove that N*M will not overflow. The preferred solution is
// to widen the IV, which avoids overflow checks, so that is tried first. If
// the IV cannot be widened, then we try to determine that this new tripcount
// expression won't overflow.
//
// Q: Does LoopFlatten use SCEV?
// Short answer: Yes and no.
//
// Long answer:
// For this transformation to be valid, we require all uses of the induction
// variables to be linear expressions of the form i*M+j. The different Loop
// APIs are used to get some loop components like the induction variable,
// compare statement, etc. In addition, we do some pattern matching to find the
// linear expressions and other loop components like the loop increment. The
// latter are examples of expressions that do use the induction variable, but
// are safe to ignore when we check all uses to be of the form i*M+j. We keep
// track of all of this in bookkeeping struct FlattenInfo.
// We assume the loops to be canonical, i.e. starting at 0 and increment with
// 1. This makes RHS of the compare the loop tripcount (with the right
// predicate). We use SCEV to then sanity check that this tripcount matches
// with the tripcount as computed by SCEV.
//
//===----------------------------------------------------------------------===//

Expand Down Expand Up @@ -75,30 +98,48 @@ static cl::opt<bool>
cl::desc("Widen the loop induction variables, if possible, so "
"overflow checks won't reject flattening"));

// We require all uses of both induction variables to match this pattern:
//
// (OuterPHI * InnerTripCount) + InnerPHI
//
// I.e., it needs to be a linear expression of the induction variables and the
// inner loop trip count. We keep track of all different expressions on which
// checks will be performed in this bookkeeping struct.
//
struct FlattenInfo {
Loop *OuterLoop = nullptr;
Loop *OuterLoop = nullptr; // The loop pair to be flattened.
Loop *InnerLoop = nullptr;
// These PHINodes correspond to loop induction variables, which are expected
// to start at zero and increment by one on each loop.
PHINode *InnerInductionPHI = nullptr;
PHINode *OuterInductionPHI = nullptr;
Value *InnerTripCount = nullptr;
Value *OuterTripCount = nullptr;
BinaryOperator *InnerIncrement = nullptr;
BinaryOperator *OuterIncrement = nullptr;
BranchInst *InnerBranch = nullptr;
BranchInst *OuterBranch = nullptr;
SmallPtrSet<Value *, 4> LinearIVUses;

PHINode *InnerInductionPHI = nullptr; // These PHINodes correspond to loop
PHINode *OuterInductionPHI = nullptr; // induction variables, which are
// expected to start at zero and
// increment by one on each loop.

Value *InnerTripCount = nullptr; // The product of these two tripcounts
Value *OuterTripCount = nullptr; // will be the new flattened loop
// tripcount. Also used to recognise a
// linear expression that will be replaced.

SmallPtrSet<Value *, 4> LinearIVUses; // Contains the linear expressions
// of the form i*M+j that will be
// replaced.

BinaryOperator *InnerIncrement = nullptr; // Uses of induction variables in
BinaryOperator *OuterIncrement = nullptr; // loop control statements that
BranchInst *InnerBranch = nullptr; // are safe to ignore.

BranchInst *OuterBranch = nullptr; // The instruction that needs to be
// updated with new tripcount.

SmallPtrSet<PHINode *, 4> InnerPHIsToTransform;

// Whether this holds the flatten info before or after widening.
bool Widened = false;
bool Widened = false; // Whether this holds the flatten info before or after
// widening.

// Holds the old/narrow induction phis, i.e. the Phis before IV widening has
// been applied. This bookkeeping is used so we can skip some checks on these
// phi nodes.
PHINode *NarrowInnerInductionPHI = nullptr;
PHINode *NarrowOuterInductionPHI = nullptr;
PHINode *NarrowInnerInductionPHI = nullptr; // Holds the old/narrow induction
PHINode *NarrowOuterInductionPHI = nullptr; // phis, i.e. the Phis before IV
// has been apllied. Used to skip
// checks on phi nodes.

FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL){};

Expand All @@ -108,6 +149,118 @@ struct FlattenInfo {
return false;
return NarrowInnerInductionPHI == Phi || NarrowOuterInductionPHI == Phi;
}
bool isInnerLoopIncrement(User *U) {
return InnerIncrement == U;
}
bool isOuterLoopIncrement(User *U) {
return OuterIncrement == U;
}
bool isInnerLoopTest(User *U) {
return InnerBranch->getCondition() == U;
}

bool checkOuterInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {
for (User *U : OuterInductionPHI->users()) {
if (isOuterLoopIncrement(U))
continue;

auto IsValidOuterPHIUses = [&] (User *U) -> bool {
LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump());
if (!ValidOuterPHIUses.count(U)) {
LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
return false;
}
LLVM_DEBUG(dbgs() << "Use is optimisable\n");
return true;
};

if (auto *V = dyn_cast<TruncInst>(U)) {
for (auto *K : V->users()) {
if (!IsValidOuterPHIUses(K))
return false;
}
continue;
}

if (!IsValidOuterPHIUses(U))
return false;
}
return true;
}

bool matchLinearIVUser(User *U, Value *InnerTripCount,
SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {
LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump());
Value *MatchedMul = nullptr;
Value *MatchedItCount = nullptr;

bool IsAdd = match(U, m_c_Add(m_Specific(InnerInductionPHI),
m_Value(MatchedMul))) &&
match(MatchedMul, m_c_Mul(m_Specific(OuterInductionPHI),
m_Value(MatchedItCount)));

// Matches the same pattern as above, except it also looks for truncs
// on the phi, which can be the result of widening the induction variables.
bool IsAddTrunc =
match(U, m_c_Add(m_Trunc(m_Specific(InnerInductionPHI)),
m_Value(MatchedMul))) &&
match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(OuterInductionPHI)),
m_Value(MatchedItCount)));

if (!MatchedItCount)
return false;

// Look through extends if the IV has been widened.
if (Widened &&
(isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) {
assert(MatchedItCount->getType() == InnerInductionPHI->getType() &&
"Unexpected type mismatch in types after widening");
MatchedItCount = isa<SExtInst>(MatchedItCount)
? dyn_cast<SExtInst>(MatchedItCount)->getOperand(0)
: dyn_cast<ZExtInst>(MatchedItCount)->getOperand(0);
}

if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) {
LLVM_DEBUG(dbgs() << "Use is optimisable\n");
ValidOuterPHIUses.insert(MatchedMul);
LinearIVUses.insert(U);
return true;
}

LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
return false;
}

bool checkInnerInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {
Value *SExtInnerTripCount = InnerTripCount;
if (Widened &&
(isa<SExtInst>(InnerTripCount) || isa<ZExtInst>(InnerTripCount)))
SExtInnerTripCount = cast<Instruction>(InnerTripCount)->getOperand(0);

for (User *U : InnerInductionPHI->users()) {
if (isInnerLoopIncrement(U))
continue;

// After widening the IVs, a trunc instruction might have been introduced,
// so look through truncs.
if (isa<TruncInst>(U)) {
if (!U->hasOneUse())
return false;
U = *U->user_begin();
}

// If the use is in the compare (which is also the condition of the inner
// branch) then the compare has been altered by another transformation e.g
// icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is
// a constant. Ignore this use as the compare gets removed later anyway.
if (isInnerLoopTest(U))
continue;

if (!matchLinearIVUser(U, SExtInnerTripCount, ValidOuterPHIUses))
return false;
}
return true;
}
};

static bool
Expand All @@ -121,6 +274,77 @@ setLoopComponents(Value *&TC, Value *&TripCount, BinaryOperator *&Increment,
return true;
}

// Given the RHS of the loop latch compare instruction, verify with SCEV
// that this is indeed the loop tripcount.
// TODO: This used to be a straightforward check but has grown to be quite
// complicated now. It is therefore worth revisiting what the additional
// benefits are of this (compared to relying on canonical loops and pattern
// matching).
static bool verifyTripCount(Value *RHS, Loop *L,
SmallPtrSetImpl<Instruction *> &IterationInstructions,
PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment,
BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) {
const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n");
return false;
}

// The Extend=false flag is used for getTripCountFromExitCount as we want
// to verify and match it with the pattern matched tripcount. Please note
// that overflow checks are performed in checkOverflow, but are first tried
// to avoid by widening the IV.
const SCEV *SCEVTripCount =
SE->getTripCountFromExitCount(BackedgeTakenCount, /*Extend=*/false);

const SCEV *SCEVRHS = SE->getSCEV(RHS);
if (SCEVRHS == SCEVTripCount)
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
ConstantInt *ConstantRHS = dyn_cast<ConstantInt>(RHS);
if (ConstantRHS) {
const SCEV *BackedgeTCExt = nullptr;
if (IsWidened) {
const SCEV *SCEVTripCountExt;
// Find the extended backedge taken count and extended trip count using
// SCEV. One of these should now match the RHS of the compare.
BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());
SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, false);
if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
return false;
}
}
// If the RHS of the compare is equal to the backedge taken count we need
// to add one to get the trip count.
if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
ConstantInt *One = ConstantInt::get(ConstantRHS->getType(), 1);
Value *NewRHS = ConstantInt::get(
ConstantRHS->getContext(), ConstantRHS->getValue() + One->getValue());
return setLoopComponents(NewRHS, TripCount, Increment,
IterationInstructions);
}
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
}
// If the RHS isn't a constant then check that the reason it doesn't match
// the SCEV trip count is because the RHS is a ZExt or SExt instruction
// (and take the trip count to be the RHS).
if (!IsWidened) {
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
return false;
}
auto *TripCountInst = dyn_cast<Instruction>(RHS);
if (!TripCountInst) {
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
return false;
}
if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
return false;
}
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
}

// Finds the induction variable, increment and trip count for a simple loop that
// we can flatten.
static bool findLoopComponents(
Expand Down Expand Up @@ -197,63 +421,9 @@ static bool findLoopComponents(
// another transformation has changed the compare (e.g. icmp ult %inc,
// tripcount -> icmp ult %j, tripcount-1), or both.
Value *RHS = Compare->getOperand(1);
const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n");
return false;
}
// The use of the Extend=false flag on getTripCountFromExitCount was added
// during a refactoring to preserve existing behavior. However, there's
// nothing obvious in the surrounding code when handles the overflow case.
// FIXME: audit code to establish whether there's a latent bug here.
const SCEV *SCEVTripCount =
SE->getTripCountFromExitCount(BackedgeTakenCount, false);
const SCEV *SCEVRHS = SE->getSCEV(RHS);
if (SCEVRHS == SCEVTripCount)
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
ConstantInt *ConstantRHS = dyn_cast<ConstantInt>(RHS);
if (ConstantRHS) {
const SCEV *BackedgeTCExt = nullptr;
if (IsWidened) {
const SCEV *SCEVTripCountExt;
// Find the extended backedge taken count and extended trip count using
// SCEV. One of these should now match the RHS of the compare.
BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());
SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, false);
if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
return false;
}
}
// If the RHS of the compare is equal to the backedge taken count we need
// to add one to get the trip count.
if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
ConstantInt *One = ConstantInt::get(ConstantRHS->getType(), 1);
Value *NewRHS = ConstantInt::get(
ConstantRHS->getContext(), ConstantRHS->getValue() + One->getValue());
return setLoopComponents(NewRHS, TripCount, Increment,
IterationInstructions);
}
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
}
// If the RHS isn't a constant then check that the reason it doesn't match
// the SCEV trip count is because the RHS is a ZExt or SExt instruction
// (and take the trip count to be the RHS).
if (!IsWidened) {
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
return false;
}
auto *TripCountInst = dyn_cast<Instruction>(RHS);
if (!TripCountInst) {
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
return false;
}
if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
return false;
}
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);

return verifyTripCount(RHS, L, IterationInstructions, InductionPHI, TripCount,
Increment, BackBranch, SE, IsWidened);
}

static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) {
Expand Down Expand Up @@ -399,108 +569,26 @@ checkOuterLoopInsts(FlattenInfo &FI,
return true;
}

static bool checkIVUsers(FlattenInfo &FI) {
// We require all uses of both induction variables to match this pattern:
//
// (OuterPHI * InnerTripCount) + InnerPHI
//
// Any uses of the induction variables not matching that pattern would
// require a div/mod to reconstruct in the flattened loop, so the
// transformation wouldn't be profitable.

Value *InnerTripCount = FI.InnerTripCount;
if (FI.Widened &&
(isa<SExtInst>(InnerTripCount) || isa<ZExtInst>(InnerTripCount)))
InnerTripCount = cast<Instruction>(InnerTripCount)->getOperand(0);


// We require all uses of both induction variables to match this pattern:
//
// (OuterPHI * InnerTripCount) + InnerPHI
//
// Any uses of the induction variables not matching that pattern would
// require a div/mod to reconstruct in the flattened loop, so the
// transformation wouldn't be profitable.
static bool checkIVUsers(FlattenInfo &FI) {
// Check that all uses of the inner loop's induction variable match the
// expected pattern, recording the uses of the outer IV.
SmallPtrSet<Value *, 4> ValidOuterPHIUses;
for (User *U : FI.InnerInductionPHI->users()) {
if (U == FI.InnerIncrement)
continue;

// After widening the IVs, a trunc instruction might have been introduced,
// so look through truncs.
if (isa<TruncInst>(U)) {
if (!U->hasOneUse())
return false;
U = *U->user_begin();
}

// If the use is in the compare (which is also the condition of the inner
// branch) then the compare has been altered by another transformation e.g
// icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is
// a constant. Ignore this use as the compare gets removed later anyway.
if (U == FI.InnerBranch->getCondition())
continue;

LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump());

Value *MatchedMul = nullptr;
Value *MatchedItCount = nullptr;
bool IsAdd = match(U, m_c_Add(m_Specific(FI.InnerInductionPHI),
m_Value(MatchedMul))) &&
match(MatchedMul, m_c_Mul(m_Specific(FI.OuterInductionPHI),
m_Value(MatchedItCount)));

// Matches the same pattern as above, except it also looks for truncs
// on the phi, which can be the result of widening the induction variables.
bool IsAddTrunc =
match(U, m_c_Add(m_Trunc(m_Specific(FI.InnerInductionPHI)),
m_Value(MatchedMul))) &&
match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)),
m_Value(MatchedItCount)));

if (!MatchedItCount)
return false;
// Look through extends if the IV has been widened.
if (FI.Widened &&
(isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) {
assert(MatchedItCount->getType() == FI.InnerInductionPHI->getType() &&
"Unexpected type mismatch in types after widening");
MatchedItCount = isa<SExtInst>(MatchedItCount)
? dyn_cast<SExtInst>(MatchedItCount)->getOperand(0)
: dyn_cast<ZExtInst>(MatchedItCount)->getOperand(0);
}

if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) {
LLVM_DEBUG(dbgs() << "Use is optimisable\n");
ValidOuterPHIUses.insert(MatchedMul);
FI.LinearIVUses.insert(U);
} else {
LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
return false;
}
}
if (!FI.checkInnerInductionPhiUsers(ValidOuterPHIUses))
return false;

// Check that there are no uses of the outer IV other than the ones found
// as part of the pattern above.
for (User *U : FI.OuterInductionPHI->users()) {
if (U == FI.OuterIncrement)
continue;

auto IsValidOuterPHIUses = [&] (User *U) -> bool {
LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump());
if (!ValidOuterPHIUses.count(U)) {
LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
return false;
}
LLVM_DEBUG(dbgs() << "Use is optimisable\n");
return true;
};

if (auto *V = dyn_cast<TruncInst>(U)) {
for (auto *K : V->users()) {
if (!IsValidOuterPHIUses(K))
return false;
}
continue;
}

if (!IsValidOuterPHIUses(U))
return false;
}
if (!FI.checkOuterInductionPhiUsers(ValidOuterPHIUses))
return false;

LLVM_DEBUG(dbgs() << "checkIVUsers: OK\n";
dbgs() << "Found " << FI.LinearIVUses.size()
Expand Down