Skip to content

Commit a67a46f

Browse files
committed
[CodeGen] Use VirtRegOrUnit where appropriate (NFCI)
Use it in `printVRegOrUnit()`, `getPressureSets()`/`PSetIterator`, and in functions/classes dealing with register pressure. Static type checking revealed several bugs, mainly in MachinePipeliner. I'm not very familiar with this pass, so I left a bunch of FIXMEs. There is one bug in `findUseBetween()` in RegisterPressure.cpp, also annotated with a FIXME.
1 parent b9eb974 commit a67a46f

File tree

12 files changed

+292
-246
lines changed

12 files changed

+292
-246
lines changed

llvm/include/llvm/CodeGen/MachineRegisterInfo.h

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,7 @@ class MachineRegisterInfo {
637637
/// Get an iterator over the pressure sets affected by the given physical or
638638
/// virtual register. If RegUnit is physical, it must be a register unit (from
639639
/// MCRegUnitIterator).
640-
PSetIterator getPressureSets(Register RegUnit) const;
640+
PSetIterator getPressureSets(VirtRegOrUnit VRegOrUnit) const;
641641

642642
//===--------------------------------------------------------------------===//
643643
// Virtual Register Info
@@ -1249,15 +1249,16 @@ class PSetIterator {
12491249
public:
12501250
PSetIterator() = default;
12511251

1252-
PSetIterator(Register RegUnit, const MachineRegisterInfo *MRI) {
1252+
PSetIterator(VirtRegOrUnit VRegOrUnit, const MachineRegisterInfo *MRI) {
12531253
const TargetRegisterInfo *TRI = MRI->getTargetRegisterInfo();
1254-
if (RegUnit.isVirtual()) {
1255-
const TargetRegisterClass *RC = MRI->getRegClass(RegUnit);
1254+
if (VRegOrUnit.isVirtualReg()) {
1255+
const TargetRegisterClass *RC =
1256+
MRI->getRegClass(VRegOrUnit.asVirtualReg());
12561257
PSet = TRI->getRegClassPressureSets(RC);
12571258
Weight = TRI->getRegClassWeight(RC).RegWeight;
12581259
} else {
1259-
PSet = TRI->getRegUnitPressureSets(RegUnit);
1260-
Weight = TRI->getRegUnitWeight(RegUnit);
1260+
PSet = TRI->getRegUnitPressureSets(VRegOrUnit.asMCRegUnit());
1261+
Weight = TRI->getRegUnitWeight(VRegOrUnit.asMCRegUnit());
12611262
}
12621263
if (*PSet == -1)
12631264
PSet = nullptr;
@@ -1278,8 +1279,8 @@ class PSetIterator {
12781279
};
12791280

12801281
inline PSetIterator
1281-
MachineRegisterInfo::getPressureSets(Register RegUnit) const {
1282-
return PSetIterator(RegUnit, this);
1282+
MachineRegisterInfo::getPressureSets(VirtRegOrUnit VRegOrUnit) const {
1283+
return PSetIterator(VRegOrUnit, this);
12831284
}
12841285

12851286
} // end namespace llvm

llvm/include/llvm/CodeGen/Register.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,10 @@ class VirtRegOrUnit {
206206
constexpr bool operator==(const VirtRegOrUnit &Other) const {
207207
return VRegOrUnit == Other.VRegOrUnit;
208208
}
209+
210+
constexpr bool operator<(const VirtRegOrUnit &Other) const {
211+
return VRegOrUnit < Other.VRegOrUnit;
212+
}
209213
};
210214

211215
} // namespace llvm

llvm/include/llvm/CodeGen/RegisterPressure.h

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ class MachineRegisterInfo;
3737
class RegisterClassInfo;
3838

3939
struct VRegMaskOrUnit {
40-
Register RegUnit; ///< Virtual register or register unit.
40+
VirtRegOrUnit VRegOrUnit;
4141
LaneBitmask LaneMask;
4242

43-
VRegMaskOrUnit(Register RegUnit, LaneBitmask LaneMask)
44-
: RegUnit(RegUnit), LaneMask(LaneMask) {}
43+
VRegMaskOrUnit(VirtRegOrUnit VRegOrUnit, LaneBitmask LaneMask)
44+
: VRegOrUnit(VRegOrUnit), LaneMask(LaneMask) {}
4545
};
4646

4747
/// Base class for register pressure results.
@@ -157,7 +157,7 @@ class PressureDiff {
157157
const_iterator begin() const { return &PressureChanges[0]; }
158158
const_iterator end() const { return &PressureChanges[MaxPSets]; }
159159

160-
LLVM_ABI void addPressureChange(Register RegUnit, bool IsDec,
160+
LLVM_ABI void addPressureChange(VirtRegOrUnit VRegOrUnit, bool IsDec,
161161
const MachineRegisterInfo *MRI);
162162

163163
LLVM_ABI void dump(const TargetRegisterInfo &TRI) const;
@@ -279,25 +279,25 @@ class LiveRegSet {
279279
RegSet Regs;
280280
unsigned NumRegUnits = 0u;
281281

282-
unsigned getSparseIndexFromReg(Register Reg) const {
283-
if (Reg.isVirtual())
284-
return Reg.virtRegIndex() + NumRegUnits;
285-
assert(Reg < NumRegUnits);
286-
return Reg.id();
282+
unsigned getSparseIndexFromVirtRegOrUnit(VirtRegOrUnit VRegOrUnit) const {
283+
if (VRegOrUnit.isVirtualReg())
284+
return VRegOrUnit.asVirtualReg().virtRegIndex() + NumRegUnits;
285+
assert(VRegOrUnit.asMCRegUnit() < NumRegUnits);
286+
return VRegOrUnit.asMCRegUnit();
287287
}
288288

289-
Register getRegFromSparseIndex(unsigned SparseIndex) const {
289+
VirtRegOrUnit getVirtRegOrUnitFromSparseIndex(unsigned SparseIndex) const {
290290
if (SparseIndex >= NumRegUnits)
291-
return Register::index2VirtReg(SparseIndex - NumRegUnits);
292-
return Register(SparseIndex);
291+
return VirtRegOrUnit(Register::index2VirtReg(SparseIndex - NumRegUnits));
292+
return VirtRegOrUnit(SparseIndex);
293293
}
294294

295295
public:
296296
LLVM_ABI void clear();
297297
LLVM_ABI void init(const MachineRegisterInfo &MRI);
298298

299-
LaneBitmask contains(Register Reg) const {
300-
unsigned SparseIndex = getSparseIndexFromReg(Reg);
299+
LaneBitmask contains(VirtRegOrUnit VRegOrUnit) const {
300+
unsigned SparseIndex = getSparseIndexFromVirtRegOrUnit(VRegOrUnit);
301301
RegSet::const_iterator I = Regs.find(SparseIndex);
302302
if (I == Regs.end())
303303
return LaneBitmask::getNone();
@@ -307,7 +307,7 @@ class LiveRegSet {
307307
/// Mark the \p Pair.LaneMask lanes of \p Pair.Reg as live.
308308
/// Returns the previously live lanes of \p Pair.Reg.
309309
LaneBitmask insert(VRegMaskOrUnit Pair) {
310-
unsigned SparseIndex = getSparseIndexFromReg(Pair.RegUnit);
310+
unsigned SparseIndex = getSparseIndexFromVirtRegOrUnit(Pair.VRegOrUnit);
311311
auto InsertRes = Regs.insert(IndexMaskPair(SparseIndex, Pair.LaneMask));
312312
if (!InsertRes.second) {
313313
LaneBitmask PrevMask = InsertRes.first->LaneMask;
@@ -320,7 +320,7 @@ class LiveRegSet {
320320
/// Clears the \p Pair.LaneMask lanes of \p Pair.Reg (mark them as dead).
321321
/// Returns the previously live lanes of \p Pair.Reg.
322322
LaneBitmask erase(VRegMaskOrUnit Pair) {
323-
unsigned SparseIndex = getSparseIndexFromReg(Pair.RegUnit);
323+
unsigned SparseIndex = getSparseIndexFromVirtRegOrUnit(Pair.VRegOrUnit);
324324
RegSet::iterator I = Regs.find(SparseIndex);
325325
if (I == Regs.end())
326326
return LaneBitmask::getNone();
@@ -335,9 +335,9 @@ class LiveRegSet {
335335

336336
void appendTo(SmallVectorImpl<VRegMaskOrUnit> &To) const {
337337
for (const IndexMaskPair &P : Regs) {
338-
Register Reg = getRegFromSparseIndex(P.Index);
338+
VirtRegOrUnit VRegOrUnit = getVirtRegOrUnitFromSparseIndex(P.Index);
339339
if (P.LaneMask.any())
340-
To.emplace_back(Reg, P.LaneMask);
340+
To.emplace_back(VRegOrUnit, P.LaneMask);
341341
}
342342
}
343343
};
@@ -541,9 +541,11 @@ class RegPressureTracker {
541541

542542
LLVM_ABI void dump() const;
543543

544-
LLVM_ABI void increaseRegPressure(Register RegUnit, LaneBitmask PreviousMask,
544+
LLVM_ABI void increaseRegPressure(VirtRegOrUnit VRegOrUnit,
545+
LaneBitmask PreviousMask,
545546
LaneBitmask NewMask);
546-
LLVM_ABI void decreaseRegPressure(Register RegUnit, LaneBitmask PreviousMask,
547+
LLVM_ABI void decreaseRegPressure(VirtRegOrUnit VRegOrUnit,
548+
LaneBitmask PreviousMask,
547549
LaneBitmask NewMask);
548550

549551
protected:
@@ -565,9 +567,12 @@ class RegPressureTracker {
565567
discoverLiveInOrOut(VRegMaskOrUnit Pair,
566568
SmallVectorImpl<VRegMaskOrUnit> &LiveInOrOut);
567569

568-
LLVM_ABI LaneBitmask getLastUsedLanes(Register RegUnit, SlotIndex Pos) const;
569-
LLVM_ABI LaneBitmask getLiveLanesAt(Register RegUnit, SlotIndex Pos) const;
570-
LLVM_ABI LaneBitmask getLiveThroughAt(Register RegUnit, SlotIndex Pos) const;
570+
LLVM_ABI LaneBitmask getLastUsedLanes(VirtRegOrUnit VRegOrUnit,
571+
SlotIndex Pos) const;
572+
LLVM_ABI LaneBitmask getLiveLanesAt(VirtRegOrUnit VRegOrUnit,
573+
SlotIndex Pos) const;
574+
LLVM_ABI LaneBitmask getLiveThroughAt(VirtRegOrUnit VRegOrUnit,
575+
SlotIndex Pos) const;
571576
};
572577

573578
LLVM_ABI void dumpRegSetPressure(ArrayRef<unsigned> SetPressure,

llvm/include/llvm/CodeGen/TargetRegisterInfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1450,7 +1450,7 @@ LLVM_ABI Printable printRegUnit(MCRegUnit Unit, const TargetRegisterInfo *TRI);
14501450

14511451
/// Create Printable object to print virtual registers and physical
14521452
/// registers on a \ref raw_ostream.
1453-
LLVM_ABI Printable printVRegOrUnit(unsigned VRegOrUnit,
1453+
LLVM_ABI Printable printVRegOrUnit(VirtRegOrUnit VRegOrUnit,
14541454
const TargetRegisterInfo *TRI);
14551455

14561456
/// Create Printable object to print register classes or register banks

llvm/lib/CodeGen/MachinePipeliner.cpp

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,7 +1509,10 @@ class HighRegisterPressureDetector {
15091509

15101510
void dumpPSet(Register Reg) const {
15111511
dbgs() << "Reg=" << printReg(Reg, TRI, 0, &MRI) << " PSet=";
1512-
for (auto PSetIter = MRI.getPressureSets(Reg); PSetIter.isValid();
1512+
// FIXME: The static cast is a bug compensating bugs in the callers.
1513+
VirtRegOrUnit VRegOrUnit =
1514+
Reg.isVirtual() ? VirtRegOrUnit(Reg) : VirtRegOrUnit(Reg.id());
1515+
for (auto PSetIter = MRI.getPressureSets(VRegOrUnit); PSetIter.isValid();
15131516
++PSetIter) {
15141517
dbgs() << *PSetIter << ' ';
15151518
}
@@ -1518,15 +1521,18 @@ class HighRegisterPressureDetector {
15181521

15191522
void increaseRegisterPressure(std::vector<unsigned> &Pressure,
15201523
Register Reg) const {
1521-
auto PSetIter = MRI.getPressureSets(Reg);
1524+
// FIXME: The static cast is a bug compensating bugs in the callers.
1525+
VirtRegOrUnit VRegOrUnit =
1526+
Reg.isVirtual() ? VirtRegOrUnit(Reg) : VirtRegOrUnit(Reg.id());
1527+
auto PSetIter = MRI.getPressureSets(VRegOrUnit);
15221528
unsigned Weight = PSetIter.getWeight();
15231529
for (; PSetIter.isValid(); ++PSetIter)
15241530
Pressure[*PSetIter] += Weight;
15251531
}
15261532

15271533
void decreaseRegisterPressure(std::vector<unsigned> &Pressure,
15281534
Register Reg) const {
1529-
auto PSetIter = MRI.getPressureSets(Reg);
1535+
auto PSetIter = MRI.getPressureSets(VirtRegOrUnit(Reg));
15301536
unsigned Weight = PSetIter.getWeight();
15311537
for (; PSetIter.isValid(); ++PSetIter) {
15321538
auto &P = Pressure[*PSetIter];
@@ -1559,7 +1565,11 @@ class HighRegisterPressureDetector {
15591565
if (MI.isDebugInstr())
15601566
continue;
15611567
for (auto &Use : ROMap[&MI].Uses) {
1562-
auto Reg = Use.RegUnit;
1568+
// FIXME: The static_cast is a bug.
1569+
Register Reg =
1570+
Use.VRegOrUnit.isVirtualReg()
1571+
? Use.VRegOrUnit.asVirtualReg()
1572+
: static_cast<Register>(Use.VRegOrUnit.asMCRegUnit());
15631573
// Ignore the variable that appears only on one side of phi instruction
15641574
// because it's used only at the first iteration.
15651575
if (MI.isPHI() && Reg != getLoopPhiReg(MI, OrigMBB))
@@ -1609,8 +1619,14 @@ class HighRegisterPressureDetector {
16091619
Register Reg = getLoopPhiReg(*MI, OrigMBB);
16101620
UpdateTargetRegs(Reg);
16111621
} else {
1612-
for (auto &Use : ROMap.find(MI)->getSecond().Uses)
1613-
UpdateTargetRegs(Use.RegUnit);
1622+
for (auto &Use : ROMap.find(MI)->getSecond().Uses) {
1623+
// FIXME: The static_cast is a bug.
1624+
Register Reg =
1625+
Use.VRegOrUnit.isVirtualReg()
1626+
? Use.VRegOrUnit.asVirtualReg()
1627+
: static_cast<Register>(Use.VRegOrUnit.asMCRegUnit());
1628+
UpdateTargetRegs(Reg);
1629+
}
16141630
}
16151631
}
16161632

@@ -1621,7 +1637,11 @@ class HighRegisterPressureDetector {
16211637
DenseMap<Register, MachineInstr *> LastUseMI;
16221638
for (MachineInstr *MI : llvm::reverse(OrderedInsts)) {
16231639
for (auto &Use : ROMap.find(MI)->getSecond().Uses) {
1624-
auto Reg = Use.RegUnit;
1640+
// FIXME: The static_cast is a bug.
1641+
Register Reg =
1642+
Use.VRegOrUnit.isVirtualReg()
1643+
? Use.VRegOrUnit.asVirtualReg()
1644+
: static_cast<Register>(Use.VRegOrUnit.asMCRegUnit());
16251645
if (!TargetRegs.contains(Reg))
16261646
continue;
16271647
auto [Ite, Inserted] = LastUseMI.try_emplace(Reg, MI);
@@ -1635,8 +1655,8 @@ class HighRegisterPressureDetector {
16351655
}
16361656

16371657
Instr2LastUsesTy LastUses;
1638-
for (auto &Entry : LastUseMI)
1639-
LastUses[Entry.second].insert(Entry.first);
1658+
for (auto [Reg, MI] : LastUseMI)
1659+
LastUses[MI].insert(Reg);
16401660
return LastUses;
16411661
}
16421662

@@ -1675,7 +1695,11 @@ class HighRegisterPressureDetector {
16751695
});
16761696

16771697
const auto InsertReg = [this, &CurSetPressure](RegSetTy &RegSet,
1678-
Register Reg) {
1698+
VirtRegOrUnit VRegOrUnit) {
1699+
// FIXME: The static_cast is a bug.
1700+
Register Reg = VRegOrUnit.isVirtualReg()
1701+
? VRegOrUnit.asVirtualReg()
1702+
: static_cast<Register>(VRegOrUnit.asMCRegUnit());
16791703
if (!Reg.isValid() || isReservedRegister(Reg))
16801704
return;
16811705

@@ -1712,7 +1736,7 @@ class HighRegisterPressureDetector {
17121736
const unsigned Iter = I - Stage;
17131737

17141738
for (auto &Def : ROMap.find(MI)->getSecond().Defs)
1715-
InsertReg(LiveRegSets[Iter], Def.RegUnit);
1739+
InsertReg(LiveRegSets[Iter], Def.VRegOrUnit);
17161740

17171741
for (auto LastUse : LastUses[MI]) {
17181742
if (MI->isPHI()) {
@@ -2235,30 +2259,33 @@ static void computeLiveOuts(MachineFunction &MF, RegPressureTracker &RPTracker,
22352259
const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
22362260
MachineRegisterInfo &MRI = MF.getRegInfo();
22372261
SmallVector<VRegMaskOrUnit, 8> LiveOutRegs;
2238-
SmallSet<Register, 4> Uses;
2262+
SmallSet<VirtRegOrUnit, 4> Uses;
22392263
for (SUnit *SU : NS) {
22402264
const MachineInstr *MI = SU->getInstr();
22412265
if (MI->isPHI())
22422266
continue;
22432267
for (const MachineOperand &MO : MI->all_uses()) {
22442268
Register Reg = MO.getReg();
22452269
if (Reg.isVirtual())
2246-
Uses.insert(Reg);
2270+
Uses.insert(VirtRegOrUnit(Reg));
22472271
else if (MRI.isAllocatable(Reg))
2248-
Uses.insert_range(TRI->regunits(Reg.asMCReg()));
2272+
for (MCRegUnit Unit : TRI->regunits(Reg.asMCReg()))
2273+
Uses.insert(VirtRegOrUnit(Unit));
22492274
}
22502275
}
22512276
for (SUnit *SU : NS)
22522277
for (const MachineOperand &MO : SU->getInstr()->all_defs())
22532278
if (!MO.isDead()) {
22542279
Register Reg = MO.getReg();
22552280
if (Reg.isVirtual()) {
2256-
if (!Uses.count(Reg))
2257-
LiveOutRegs.emplace_back(Reg, LaneBitmask::getNone());
2281+
if (!Uses.count(VirtRegOrUnit(Reg)))
2282+
LiveOutRegs.emplace_back(VirtRegOrUnit(Reg),
2283+
LaneBitmask::getNone());
22582284
} else if (MRI.isAllocatable(Reg)) {
22592285
for (MCRegUnit Unit : TRI->regunits(Reg.asMCReg()))
2260-
if (!Uses.count(Unit))
2261-
LiveOutRegs.emplace_back(Unit, LaneBitmask::getNone());
2286+
if (!Uses.count(VirtRegOrUnit(Unit)))
2287+
LiveOutRegs.emplace_back(VirtRegOrUnit(Unit),
2288+
LaneBitmask::getNone());
22622289
}
22632290
}
22642291
RPTracker.addLiveRegs(LiveOutRegs);

llvm/lib/CodeGen/MachineScheduler.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,10 +1580,10 @@ updateScheduledPressure(const SUnit *SU,
15801580
/// instruction.
15811581
void ScheduleDAGMILive::updatePressureDiffs(ArrayRef<VRegMaskOrUnit> LiveUses) {
15821582
for (const VRegMaskOrUnit &P : LiveUses) {
1583-
Register Reg = P.RegUnit;
15841583
/// FIXME: Currently assuming single-use physregs.
1585-
if (!Reg.isVirtual())
1584+
if (!P.VRegOrUnit.isVirtualReg())
15861585
continue;
1586+
Register Reg = P.VRegOrUnit.asVirtualReg();
15871587

15881588
if (ShouldTrackLaneMasks) {
15891589
// If the register has just become live then other uses won't change
@@ -1599,7 +1599,7 @@ void ScheduleDAGMILive::updatePressureDiffs(ArrayRef<VRegMaskOrUnit> LiveUses) {
15991599
continue;
16001600

16011601
PressureDiff &PDiff = getPressureDiff(&SU);
1602-
PDiff.addPressureChange(Reg, Decrement, &MRI);
1602+
PDiff.addPressureChange(VirtRegOrUnit(Reg), Decrement, &MRI);
16031603
if (llvm::any_of(PDiff, [](const PressureChange &Change) {
16041604
return Change.isValid();
16051605
}))
@@ -1611,7 +1611,7 @@ void ScheduleDAGMILive::updatePressureDiffs(ArrayRef<VRegMaskOrUnit> LiveUses) {
16111611
}
16121612
} else {
16131613
assert(P.LaneMask.any());
1614-
LLVM_DEBUG(dbgs() << " LiveReg: " << printVRegOrUnit(Reg, TRI) << "\n");
1614+
LLVM_DEBUG(dbgs() << " LiveReg: " << printReg(Reg, TRI) << "\n");
16151615
// This may be called before CurrentBottom has been initialized. However,
16161616
// BotRPTracker must have a valid position. We want the value live into the
16171617
// instruction or live out of the block, so ask for the previous
@@ -1638,7 +1638,7 @@ void ScheduleDAGMILive::updatePressureDiffs(ArrayRef<VRegMaskOrUnit> LiveUses) {
16381638
LI.Query(LIS->getInstructionIndex(*SU->getInstr()));
16391639
if (LRQ.valueIn() == VNI) {
16401640
PressureDiff &PDiff = getPressureDiff(SU);
1641-
PDiff.addPressureChange(Reg, true, &MRI);
1641+
PDiff.addPressureChange(VirtRegOrUnit(Reg), true, &MRI);
16421642
if (llvm::any_of(PDiff, [](const PressureChange &Change) {
16431643
return Change.isValid();
16441644
}))
@@ -1814,9 +1814,9 @@ unsigned ScheduleDAGMILive::computeCyclicCriticalPath() {
18141814
unsigned MaxCyclicLatency = 0;
18151815
// Visit each live out vreg def to find def/use pairs that cross iterations.
18161816
for (const VRegMaskOrUnit &P : RPTracker.getPressure().LiveOutRegs) {
1817-
Register Reg = P.RegUnit;
1818-
if (!Reg.isVirtual())
1817+
if (!P.VRegOrUnit.isVirtualReg())
18191818
continue;
1819+
Register Reg = P.VRegOrUnit.asVirtualReg();
18201820
const LiveInterval &LI = LIS->getInterval(Reg);
18211821
const VNInfo *DefVNI = LI.getVNInfoBefore(LIS->getMBBEndIdx(BB));
18221822
if (!DefVNI)

0 commit comments

Comments
 (0)