Skip to content

Commit

Permalink
[LV] Clean up LoopVectorizationCostModel::calculateRegisterUsage. NFC
Browse files Browse the repository at this point in the history
Minor refactoring in LoopVectorizationCostModel::calculateRegisterUsage.

Also adding some FIXME:s related to what appears to be some short
comings related to how the register usage is calculated.

Differential Revision: https://reviews.llvm.org/D138342
  • Loading branch information
bjope committed Nov 20, 2022
1 parent 294fdd9 commit 1c308d6
Showing 1 changed file with 27 additions and 16 deletions.
43 changes: 27 additions & 16 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Expand Up @@ -5951,8 +5951,9 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) {
IntervalMap EndPoint;
// Saves the list of instruction indices that are used in the loop.
SmallPtrSet<Instruction *, 8> Ends;
// Saves the list of values that are used in the loop but are
// defined outside the loop, such as arguments and constants.
// Saves the list of values that are used in the loop but are defined outside
// the loop (not including non-instruction values such as arguments and
// constants).
SmallPtrSet<Value *, 8> LoopInvariants;

for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) {
Expand All @@ -5964,6 +5965,9 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) {
auto *Instr = dyn_cast<Instruction>(U);

// Ignore non-instruction values such as arguments, constants, etc.
// FIXME: Might need some motivation why these values are ignored. If
// for example an argument is used inside the loop it will increase the
// register pressure (so shouldn't we add it to LoopInvariants).
if (!Instr)
continue;

Expand Down Expand Up @@ -6019,14 +6023,19 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) {

// For each VF find the maximum usage of registers.
for (unsigned j = 0, e = VFs.size(); j < e; ++j) {
// Count the number of live intervals.
// Count the number of registers used, per register class, given all open
// intervals.
// Note that elements in this SmallMapVector will be default constructed
// as 0. So we can use "RegUsage[ClassID] += n" in the code below even if
// there is no previous entry for ClassID.
SmallMapVector<unsigned, unsigned, 4> RegUsage;

if (VFs[j].isScalar()) {
for (auto *Inst : OpenIntervals) {
unsigned ClassID = TTI.getRegisterClassForType(false, Inst->getType());
// If RegUsage[ClassID] doesn't exist, it will be default
// constructed as 0 before the addition
unsigned ClassID =
TTI.getRegisterClassForType(false, Inst->getType());
// FIXME: The target might use more than one register for the type
// even in the scalar case.
RegUsage[ClassID] += 1;
}
} else {
Expand All @@ -6036,14 +6045,14 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) {
if (VecValuesToIgnore.count(Inst))
continue;
if (isScalarAfterVectorization(Inst, VFs[j])) {
unsigned ClassID = TTI.getRegisterClassForType(false, Inst->getType());
// If RegUsage[ClassID] doesn't exist, it will be default
// constructed as 0 before the addition
unsigned ClassID =
TTI.getRegisterClassForType(false, Inst->getType());
// FIXME: The target might use more than one register for the type
// even in the scalar case.
RegUsage[ClassID] += 1;
} else {
unsigned ClassID = TTI.getRegisterClassForType(true, Inst->getType());
// If RegUsage[ClassID] doesn't exist, it will be default
// constructed as 0 before the addition
unsigned ClassID =
TTI.getRegisterClassForType(true, Inst->getType());
RegUsage[ClassID] += GetRegUsage(Inst->getType(), VFs[j]);
}
}
Expand All @@ -6063,17 +6072,19 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) {
}

for (unsigned i = 0, e = VFs.size(); i < e; ++i) {
// Note that elements in this SmallMapVector will be default constructed
// as 0. So we can use "Invariant[ClassID] += n" in the code below even if
// there is no previous entry for ClassID.
SmallMapVector<unsigned, unsigned, 4> Invariant;

for (auto *Inst : LoopInvariants) {
// FIXME: The target might use more than one register for the type
// even in the scalar case.
unsigned Usage =
VFs[i].isScalar() ? 1 : GetRegUsage(Inst->getType(), VFs[i]);
unsigned ClassID =
TTI.getRegisterClassForType(VFs[i].isVector(), Inst->getType());
if (Invariant.find(ClassID) == Invariant.end())
Invariant[ClassID] = Usage;
else
Invariant[ClassID] += Usage;
Invariant[ClassID] += Usage;
}

LLVM_DEBUG({
Expand Down

0 comments on commit 1c308d6

Please sign in to comment.