Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions llvm/include/llvm/CodeGen/MachineRegisterInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -634,10 +634,9 @@ class MachineRegisterInfo {
/// function. Writing to a constant register has no effect.
LLVM_ABI bool isConstantPhysReg(MCRegister PhysReg) const;

/// Get an iterator over the pressure sets affected by the given physical or
/// virtual register. If RegUnit is physical, it must be a register unit (from
/// MCRegUnitIterator).
PSetIterator getPressureSets(Register RegUnit) const;
/// Get an iterator over the pressure sets affected by the virtual register
/// or register unit.
PSetIterator getPressureSets(VirtRegOrUnit VRegOrUnit) const;

//===--------------------------------------------------------------------===//
// Virtual Register Info
Expand Down Expand Up @@ -1249,15 +1248,16 @@ class PSetIterator {
public:
PSetIterator() = default;

PSetIterator(Register RegUnit, const MachineRegisterInfo *MRI) {
PSetIterator(VirtRegOrUnit VRegOrUnit, const MachineRegisterInfo *MRI) {
const TargetRegisterInfo *TRI = MRI->getTargetRegisterInfo();
if (RegUnit.isVirtual()) {
const TargetRegisterClass *RC = MRI->getRegClass(RegUnit);
if (VRegOrUnit.isVirtualReg()) {
const TargetRegisterClass *RC =
MRI->getRegClass(VRegOrUnit.asVirtualReg());
PSet = TRI->getRegClassPressureSets(RC);
Weight = TRI->getRegClassWeight(RC).RegWeight;
} else {
PSet = TRI->getRegUnitPressureSets(RegUnit);
Weight = TRI->getRegUnitWeight(RegUnit);
PSet = TRI->getRegUnitPressureSets(VRegOrUnit.asMCRegUnit());
Weight = TRI->getRegUnitWeight(VRegOrUnit.asMCRegUnit());
}
if (*PSet == -1)
PSet = nullptr;
Expand All @@ -1278,8 +1278,8 @@ class PSetIterator {
};

inline PSetIterator
MachineRegisterInfo::getPressureSets(Register RegUnit) const {
return PSetIterator(RegUnit, this);
MachineRegisterInfo::getPressureSets(VirtRegOrUnit VRegOrUnit) const {
return PSetIterator(VRegOrUnit, this);
}

} // end namespace llvm
Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/Register.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,10 @@ class VirtRegOrUnit {
constexpr bool operator==(const VirtRegOrUnit &Other) const {
return VRegOrUnit == Other.VRegOrUnit;
}

constexpr bool operator<(const VirtRegOrUnit &Other) const {
return VRegOrUnit < Other.VRegOrUnit;
}
};

} // namespace llvm
Expand Down
51 changes: 28 additions & 23 deletions llvm/include/llvm/CodeGen/RegisterPressure.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ class MachineRegisterInfo;
class RegisterClassInfo;

struct VRegMaskOrUnit {
Register RegUnit; ///< Virtual register or register unit.
VirtRegOrUnit VRegOrUnit;
LaneBitmask LaneMask;

VRegMaskOrUnit(Register RegUnit, LaneBitmask LaneMask)
: RegUnit(RegUnit), LaneMask(LaneMask) {}
VRegMaskOrUnit(VirtRegOrUnit VRegOrUnit, LaneBitmask LaneMask)
: VRegOrUnit(VRegOrUnit), LaneMask(LaneMask) {}
};

/// Base class for register pressure results.
Expand Down Expand Up @@ -157,7 +157,7 @@ class PressureDiff {
const_iterator begin() const { return &PressureChanges[0]; }
const_iterator end() const { return &PressureChanges[MaxPSets]; }

LLVM_ABI void addPressureChange(Register RegUnit, bool IsDec,
LLVM_ABI void addPressureChange(VirtRegOrUnit VRegOrUnit, bool IsDec,
const MachineRegisterInfo *MRI);

LLVM_ABI void dump(const TargetRegisterInfo &TRI) const;
Expand Down Expand Up @@ -279,25 +279,25 @@ class LiveRegSet {
RegSet Regs;
unsigned NumRegUnits = 0u;

unsigned getSparseIndexFromReg(Register Reg) const {
if (Reg.isVirtual())
return Reg.virtRegIndex() + NumRegUnits;
assert(Reg < NumRegUnits);
return Reg.id();
unsigned getSparseIndexFromVirtRegOrUnit(VirtRegOrUnit VRegOrUnit) const {
if (VRegOrUnit.isVirtualReg())
return VRegOrUnit.asVirtualReg().virtRegIndex() + NumRegUnits;
assert(VRegOrUnit.asMCRegUnit() < NumRegUnits);
return VRegOrUnit.asMCRegUnit();
}

Register getRegFromSparseIndex(unsigned SparseIndex) const {
VirtRegOrUnit getVirtRegOrUnitFromSparseIndex(unsigned SparseIndex) const {
if (SparseIndex >= NumRegUnits)
return Register::index2VirtReg(SparseIndex - NumRegUnits);
return Register(SparseIndex);
return VirtRegOrUnit(Register::index2VirtReg(SparseIndex - NumRegUnits));
return VirtRegOrUnit(SparseIndex);
}

public:
LLVM_ABI void clear();
LLVM_ABI void init(const MachineRegisterInfo &MRI);

LaneBitmask contains(Register Reg) const {
unsigned SparseIndex = getSparseIndexFromReg(Reg);
LaneBitmask contains(VirtRegOrUnit VRegOrUnit) const {
unsigned SparseIndex = getSparseIndexFromVirtRegOrUnit(VRegOrUnit);
RegSet::const_iterator I = Regs.find(SparseIndex);
if (I == Regs.end())
return LaneBitmask::getNone();
Expand All @@ -307,7 +307,7 @@ class LiveRegSet {
/// Mark the \p Pair.LaneMask lanes of \p Pair.Reg as live.
/// Returns the previously live lanes of \p Pair.Reg.
LaneBitmask insert(VRegMaskOrUnit Pair) {
unsigned SparseIndex = getSparseIndexFromReg(Pair.RegUnit);
unsigned SparseIndex = getSparseIndexFromVirtRegOrUnit(Pair.VRegOrUnit);
auto InsertRes = Regs.insert(IndexMaskPair(SparseIndex, Pair.LaneMask));
if (!InsertRes.second) {
LaneBitmask PrevMask = InsertRes.first->LaneMask;
Expand All @@ -320,7 +320,7 @@ class LiveRegSet {
/// Clears the \p Pair.LaneMask lanes of \p Pair.Reg (mark them as dead).
/// Returns the previously live lanes of \p Pair.Reg.
LaneBitmask erase(VRegMaskOrUnit Pair) {
unsigned SparseIndex = getSparseIndexFromReg(Pair.RegUnit);
unsigned SparseIndex = getSparseIndexFromVirtRegOrUnit(Pair.VRegOrUnit);
RegSet::iterator I = Regs.find(SparseIndex);
if (I == Regs.end())
return LaneBitmask::getNone();
Expand All @@ -335,9 +335,9 @@ class LiveRegSet {

void appendTo(SmallVectorImpl<VRegMaskOrUnit> &To) const {
for (const IndexMaskPair &P : Regs) {
Register Reg = getRegFromSparseIndex(P.Index);
VirtRegOrUnit VRegOrUnit = getVirtRegOrUnitFromSparseIndex(P.Index);
if (P.LaneMask.any())
To.emplace_back(Reg, P.LaneMask);
To.emplace_back(VRegOrUnit, P.LaneMask);
}
}
};
Expand Down Expand Up @@ -541,9 +541,11 @@ class RegPressureTracker {

LLVM_ABI void dump() const;

LLVM_ABI void increaseRegPressure(Register RegUnit, LaneBitmask PreviousMask,
LLVM_ABI void increaseRegPressure(VirtRegOrUnit VRegOrUnit,
LaneBitmask PreviousMask,
LaneBitmask NewMask);
LLVM_ABI void decreaseRegPressure(Register RegUnit, LaneBitmask PreviousMask,
LLVM_ABI void decreaseRegPressure(VirtRegOrUnit VRegOrUnit,
LaneBitmask PreviousMask,
LaneBitmask NewMask);

protected:
Expand All @@ -565,9 +567,12 @@ class RegPressureTracker {
discoverLiveInOrOut(VRegMaskOrUnit Pair,
SmallVectorImpl<VRegMaskOrUnit> &LiveInOrOut);

LLVM_ABI LaneBitmask getLastUsedLanes(Register RegUnit, SlotIndex Pos) const;
LLVM_ABI LaneBitmask getLiveLanesAt(Register RegUnit, SlotIndex Pos) const;
LLVM_ABI LaneBitmask getLiveThroughAt(Register RegUnit, SlotIndex Pos) const;
LLVM_ABI LaneBitmask getLastUsedLanes(VirtRegOrUnit VRegOrUnit,
SlotIndex Pos) const;
LLVM_ABI LaneBitmask getLiveLanesAt(VirtRegOrUnit VRegOrUnit,
SlotIndex Pos) const;
LLVM_ABI LaneBitmask getLiveThroughAt(VirtRegOrUnit VRegOrUnit,
SlotIndex Pos) const;
};

LLVM_ABI void dumpRegSetPressure(ArrayRef<unsigned> SetPressure,
Expand Down
2 changes: 1 addition & 1 deletion llvm/include/llvm/CodeGen/TargetRegisterInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -1450,7 +1450,7 @@ LLVM_ABI Printable printRegUnit(MCRegUnit Unit, const TargetRegisterInfo *TRI);

/// Create Printable object to print virtual registers and physical
/// registers on a \ref raw_ostream.
LLVM_ABI Printable printVRegOrUnit(unsigned VRegOrUnit,
LLVM_ABI Printable printVRegOrUnit(VirtRegOrUnit VRegOrUnit,
const TargetRegisterInfo *TRI);

/// Create Printable object to print register classes or register banks
Expand Down
66 changes: 48 additions & 18 deletions llvm/lib/CodeGen/MachinePipeliner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1509,7 +1509,11 @@ class HighRegisterPressureDetector {

void dumpPSet(Register Reg) const {
dbgs() << "Reg=" << printReg(Reg, TRI, 0, &MRI) << " PSet=";
for (auto PSetIter = MRI.getPressureSets(Reg); PSetIter.isValid();
// FIXME: The static_cast is a bug compensating bugs in the callers.
VirtRegOrUnit VRegOrUnit =
Reg.isVirtual() ? VirtRegOrUnit(Reg)
: VirtRegOrUnit(static_cast<MCRegUnit>(Reg.id()));
for (auto PSetIter = MRI.getPressureSets(VRegOrUnit); PSetIter.isValid();
++PSetIter) {
dbgs() << *PSetIter << ' ';
}
Expand All @@ -1518,15 +1522,19 @@ class HighRegisterPressureDetector {

void increaseRegisterPressure(std::vector<unsigned> &Pressure,
Register Reg) const {
auto PSetIter = MRI.getPressureSets(Reg);
// FIXME: The static_cast is a bug compensating bugs in the callers.
VirtRegOrUnit VRegOrUnit =
Reg.isVirtual() ? VirtRegOrUnit(Reg)
: VirtRegOrUnit(static_cast<MCRegUnit>(Reg.id()));
auto PSetIter = MRI.getPressureSets(VRegOrUnit);
unsigned Weight = PSetIter.getWeight();
for (; PSetIter.isValid(); ++PSetIter)
Pressure[*PSetIter] += Weight;
}

void decreaseRegisterPressure(std::vector<unsigned> &Pressure,
Register Reg) const {
auto PSetIter = MRI.getPressureSets(Reg);
auto PSetIter = MRI.getPressureSets(VirtRegOrUnit(Reg));
unsigned Weight = PSetIter.getWeight();
for (; PSetIter.isValid(); ++PSetIter) {
auto &P = Pressure[*PSetIter];
Expand Down Expand Up @@ -1559,7 +1567,11 @@ class HighRegisterPressureDetector {
if (MI.isDebugInstr())
continue;
for (auto &Use : ROMap[&MI].Uses) {
auto Reg = Use.RegUnit;
// FIXME: The static_cast is a bug.
Register Reg =
Use.VRegOrUnit.isVirtualReg()
? Use.VRegOrUnit.asVirtualReg()
: Register(static_cast<unsigned>(Use.VRegOrUnit.asMCRegUnit()));
// Ignore the variable that appears only on one side of phi instruction
// because it's used only at the first iteration.
if (MI.isPHI() && Reg != getLoopPhiReg(MI, OrigMBB))
Expand Down Expand Up @@ -1609,8 +1621,14 @@ class HighRegisterPressureDetector {
Register Reg = getLoopPhiReg(*MI, OrigMBB);
UpdateTargetRegs(Reg);
} else {
for (auto &Use : ROMap.find(MI)->getSecond().Uses)
UpdateTargetRegs(Use.RegUnit);
for (auto &Use : ROMap.find(MI)->getSecond().Uses) {
// FIXME: The static_cast is a bug.
Register Reg = Use.VRegOrUnit.isVirtualReg()
? Use.VRegOrUnit.asVirtualReg()
: Register(static_cast<unsigned>(
Use.VRegOrUnit.asMCRegUnit()));
UpdateTargetRegs(Reg);
}
}
}

Expand All @@ -1621,7 +1639,11 @@ class HighRegisterPressureDetector {
DenseMap<Register, MachineInstr *> LastUseMI;
for (MachineInstr *MI : llvm::reverse(OrderedInsts)) {
for (auto &Use : ROMap.find(MI)->getSecond().Uses) {
auto Reg = Use.RegUnit;
// FIXME: The static_cast is a bug.
Register Reg =
Use.VRegOrUnit.isVirtualReg()
? Use.VRegOrUnit.asVirtualReg()
: Register(static_cast<unsigned>(Use.VRegOrUnit.asMCRegUnit()));
if (!TargetRegs.contains(Reg))
continue;
auto [Ite, Inserted] = LastUseMI.try_emplace(Reg, MI);
Expand All @@ -1635,8 +1657,8 @@ class HighRegisterPressureDetector {
}

Instr2LastUsesTy LastUses;
for (auto &Entry : LastUseMI)
LastUses[Entry.second].insert(Entry.first);
for (auto [Reg, MI] : LastUseMI)
LastUses[MI].insert(Reg);
return LastUses;
}

Expand Down Expand Up @@ -1675,7 +1697,12 @@ class HighRegisterPressureDetector {
});

const auto InsertReg = [this, &CurSetPressure](RegSetTy &RegSet,
Register Reg) {
VirtRegOrUnit VRegOrUnit) {
// FIXME: The static_cast is a bug.
Register Reg =
VRegOrUnit.isVirtualReg()
? VRegOrUnit.asVirtualReg()
: Register(static_cast<unsigned>(VRegOrUnit.asMCRegUnit()));
if (!Reg.isValid() || isReservedRegister(Reg))
return;

Expand Down Expand Up @@ -1712,7 +1739,7 @@ class HighRegisterPressureDetector {
const unsigned Iter = I - Stage;

for (auto &Def : ROMap.find(MI)->getSecond().Defs)
InsertReg(LiveRegSets[Iter], Def.RegUnit);
InsertReg(LiveRegSets[Iter], Def.VRegOrUnit);

for (auto LastUse : LastUses[MI]) {
if (MI->isPHI()) {
Expand Down Expand Up @@ -2235,30 +2262,33 @@ static void computeLiveOuts(MachineFunction &MF, RegPressureTracker &RPTracker,
const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
MachineRegisterInfo &MRI = MF.getRegInfo();
SmallVector<VRegMaskOrUnit, 8> LiveOutRegs;
SmallSet<Register, 4> Uses;
SmallSet<VirtRegOrUnit, 4> Uses;
for (SUnit *SU : NS) {
const MachineInstr *MI = SU->getInstr();
if (MI->isPHI())
continue;
for (const MachineOperand &MO : MI->all_uses()) {
Register Reg = MO.getReg();
if (Reg.isVirtual())
Uses.insert(Reg);
Uses.insert(VirtRegOrUnit(Reg));
else if (MRI.isAllocatable(Reg))
Uses.insert_range(TRI->regunits(Reg.asMCReg()));
for (MCRegUnit Unit : TRI->regunits(Reg.asMCReg()))
Uses.insert(VirtRegOrUnit(Unit));
}
}
for (SUnit *SU : NS)
for (const MachineOperand &MO : SU->getInstr()->all_defs())
if (!MO.isDead()) {
Register Reg = MO.getReg();
if (Reg.isVirtual()) {
if (!Uses.count(Reg))
LiveOutRegs.emplace_back(Reg, LaneBitmask::getNone());
if (!Uses.count(VirtRegOrUnit(Reg)))
LiveOutRegs.emplace_back(VirtRegOrUnit(Reg),
LaneBitmask::getNone());
} else if (MRI.isAllocatable(Reg)) {
for (MCRegUnit Unit : TRI->regunits(Reg.asMCReg()))
if (!Uses.count(Unit))
LiveOutRegs.emplace_back(Unit, LaneBitmask::getNone());
if (!Uses.count(VirtRegOrUnit(Unit)))
LiveOutRegs.emplace_back(VirtRegOrUnit(Unit),
LaneBitmask::getNone());
}
}
RPTracker.addLiveRegs(LiveOutRegs);
Expand Down
14 changes: 7 additions & 7 deletions llvm/lib/CodeGen/MachineScheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1580,10 +1580,10 @@ updateScheduledPressure(const SUnit *SU,
/// instruction.
void ScheduleDAGMILive::updatePressureDiffs(ArrayRef<VRegMaskOrUnit> LiveUses) {
for (const VRegMaskOrUnit &P : LiveUses) {
Register Reg = P.RegUnit;
/// FIXME: Currently assuming single-use physregs.
if (!Reg.isVirtual())
if (!P.VRegOrUnit.isVirtualReg())
continue;
Register Reg = P.VRegOrUnit.asVirtualReg();

if (ShouldTrackLaneMasks) {
// If the register has just become live then other uses won't change
Expand All @@ -1599,7 +1599,7 @@ void ScheduleDAGMILive::updatePressureDiffs(ArrayRef<VRegMaskOrUnit> LiveUses) {
continue;

PressureDiff &PDiff = getPressureDiff(&SU);
PDiff.addPressureChange(Reg, Decrement, &MRI);
PDiff.addPressureChange(VirtRegOrUnit(Reg), Decrement, &MRI);
if (llvm::any_of(PDiff, [](const PressureChange &Change) {
return Change.isValid();
}))
Expand All @@ -1611,7 +1611,7 @@ void ScheduleDAGMILive::updatePressureDiffs(ArrayRef<VRegMaskOrUnit> LiveUses) {
}
} else {
assert(P.LaneMask.any());
LLVM_DEBUG(dbgs() << " LiveReg: " << printVRegOrUnit(Reg, TRI) << "\n");
LLVM_DEBUG(dbgs() << " LiveReg: " << printReg(Reg, TRI) << "\n");
// This may be called before CurrentBottom has been initialized. However,
// BotRPTracker must have a valid position. We want the value live into the
// instruction or live out of the block, so ask for the previous
Expand All @@ -1638,7 +1638,7 @@ void ScheduleDAGMILive::updatePressureDiffs(ArrayRef<VRegMaskOrUnit> LiveUses) {
LI.Query(LIS->getInstructionIndex(*SU->getInstr()));
if (LRQ.valueIn() == VNI) {
PressureDiff &PDiff = getPressureDiff(SU);
PDiff.addPressureChange(Reg, true, &MRI);
PDiff.addPressureChange(VirtRegOrUnit(Reg), true, &MRI);
if (llvm::any_of(PDiff, [](const PressureChange &Change) {
return Change.isValid();
}))
Expand Down Expand Up @@ -1814,9 +1814,9 @@ unsigned ScheduleDAGMILive::computeCyclicCriticalPath() {
unsigned MaxCyclicLatency = 0;
// Visit each live out vreg def to find def/use pairs that cross iterations.
for (const VRegMaskOrUnit &P : RPTracker.getPressure().LiveOutRegs) {
Register Reg = P.RegUnit;
if (!Reg.isVirtual())
if (!P.VRegOrUnit.isVirtualReg())
continue;
Register Reg = P.VRegOrUnit.asVirtualReg();
const LiveInterval &LI = LIS->getInterval(Reg);
const VNInfo *DefVNI = LI.getVNInfoBefore(LIS->getMBBEndIdx(BB));
if (!DefVNI)
Expand Down
Loading
Loading