Skip to content

Commit

Permalink
[NFC] Use [MC]Register in RegAllocPBQP & RegisterCoalescer
Browse files Browse the repository at this point in the history
Differential Revision: https://reviews.llvm.org/D90008
  • Loading branch information
jaingaurav committed Oct 27, 2020
1 parent 779deb9 commit 17cdba6
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 70 deletions.
2 changes: 1 addition & 1 deletion llvm/include/llvm/CodeGen/MachineRegisterInfo.h
Expand Up @@ -907,7 +907,7 @@ class MachineRegisterInfo {
///
/// Reserved registers may belong to an allocatable register class, but the
/// target has explicitly requested that they are not used.
bool isReserved(Register PhysReg) const {
bool isReserved(MCRegister PhysReg) const {
return getReservedRegs().test(PhysReg.id());
}

Expand Down
28 changes: 15 additions & 13 deletions llvm/include/llvm/CodeGen/RegAllocPBQP.h
Expand Up @@ -22,6 +22,8 @@
#include "llvm/CodeGen/PBQP/Math.h"
#include "llvm/CodeGen/PBQP/ReductionRules.h"
#include "llvm/CodeGen/PBQP/Solution.h"
#include "llvm/CodeGen/Register.h"
#include "llvm/MC/MCRegister.h"
#include "llvm/Support/ErrorHandling.h"
#include <algorithm>
#include <cassert>
Expand Down Expand Up @@ -96,13 +98,13 @@ class AllowedRegVector {
AllowedRegVector() = default;
AllowedRegVector(AllowedRegVector &&) = default;

AllowedRegVector(const std::vector<unsigned> &OptVec)
: NumOpts(OptVec.size()), Opts(new unsigned[NumOpts]) {
AllowedRegVector(const std::vector<MCRegister> &OptVec)
: NumOpts(OptVec.size()), Opts(new MCRegister[NumOpts]) {
std::copy(OptVec.begin(), OptVec.end(), Opts.get());
}

unsigned size() const { return NumOpts; }
unsigned operator[](size_t I) const { return Opts[I]; }
MCRegister operator[](size_t I) const { return Opts[I]; }

bool operator==(const AllowedRegVector &Other) const {
if (NumOpts != Other.NumOpts)
Expand All @@ -116,12 +118,12 @@ class AllowedRegVector {

private:
unsigned NumOpts = 0;
std::unique_ptr<unsigned[]> Opts;
std::unique_ptr<MCRegister[]> Opts;
};

inline hash_code hash_value(const AllowedRegVector &OptRegs) {
unsigned *OStart = OptRegs.Opts.get();
unsigned *OEnd = OptRegs.Opts.get() + OptRegs.NumOpts;
MCRegister *OStart = OptRegs.Opts.get();
MCRegister *OEnd = OptRegs.Opts.get() + OptRegs.NumOpts;
return hash_combine(OptRegs.NumOpts,
hash_combine_range(OStart, OEnd));
}
Expand All @@ -143,11 +145,11 @@ class GraphMetadata {
LiveIntervals &LIS;
MachineBlockFrequencyInfo &MBFI;

void setNodeIdForVReg(unsigned VReg, GraphBase::NodeId NId) {
VRegToNodeId[VReg] = NId;
void setNodeIdForVReg(Register VReg, GraphBase::NodeId NId) {
VRegToNodeId[VReg.id()] = NId;
}

GraphBase::NodeId getNodeIdForVReg(unsigned VReg) const {
GraphBase::NodeId getNodeIdForVReg(Register VReg) const {
auto VRegItr = VRegToNodeId.find(VReg);
if (VRegItr == VRegToNodeId.end())
return GraphBase::invalidNodeId();
Expand All @@ -159,7 +161,7 @@ class GraphMetadata {
}

private:
DenseMap<unsigned, GraphBase::NodeId> VRegToNodeId;
DenseMap<Register, GraphBase::NodeId> VRegToNodeId;
AllowedRegVecPool AllowedRegVecs;
};

Expand Down Expand Up @@ -197,8 +199,8 @@ class NodeMetadata {
NodeMetadata(NodeMetadata &&) = default;
NodeMetadata& operator=(NodeMetadata &&) = default;

void setVReg(unsigned VReg) { this->VReg = VReg; }
unsigned getVReg() const { return VReg; }
void setVReg(Register VReg) { this->VReg = VReg; }
Register getVReg() const { return VReg; }

void setAllowedRegs(GraphMetadata::AllowedRegVecRef AllowedRegs) {
this->AllowedRegs = std::move(AllowedRegs);
Expand Down Expand Up @@ -256,7 +258,7 @@ class NodeMetadata {
unsigned NumOpts = 0;
unsigned DeniedOpts = 0;
std::unique_ptr<unsigned[]> OptUnsafeEdges;
unsigned VReg = 0;
Register VReg;
GraphMetadata::AllowedRegVecRef AllowedRegs;

#ifndef NDEBUG
Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/MC/MCRegister.h
Expand Up @@ -20,6 +20,7 @@ using MCPhysReg = uint16_t;

/// Wrapper class representing physical registers. Should be passed by value.
class MCRegister {
friend hash_code hash_value(const MCRegister &);
unsigned Reg;

public:
Expand Down Expand Up @@ -105,6 +106,9 @@ template<> struct DenseMapInfo<MCRegister> {
}
};

inline hash_code hash_value(const MCRegister &Reg) {
return hash_value(Reg.id());
}
}

#endif // ifndef LLVM_MC_REGISTER_H
47 changes: 24 additions & 23 deletions llvm/lib/CodeGen/RegAllocPBQP.cpp
Expand Up @@ -147,7 +147,7 @@ class RegAllocPBQP : public MachineFunctionPass {
using AllowedSetMap = std::vector<AllowedSet>;
using RegPair = std::pair<unsigned, unsigned>;
using CoalesceMap = std::map<RegPair, PBQP::PBQPNum>;
using RegSet = std::set<unsigned>;
using RegSet = std::set<Register>;

char *customPassID;

Expand Down Expand Up @@ -331,7 +331,7 @@ class Interference : public PBQPRAConstraint {

// Start by building the inactive set.
for (auto NId : G.nodeIds()) {
unsigned VReg = G.getNodeMetadata(NId).getVReg();
Register VReg = G.getNodeMetadata(NId).getVReg();
LiveInterval &LI = LIS.getInterval(VReg);
assert(!LI.empty() && "PBQP graph contains node for empty interval");
Inactive.push(std::make_tuple(&LI, 0, NId));
Expand Down Expand Up @@ -413,9 +413,9 @@ class Interference : public PBQPRAConstraint {
PBQPRAGraph::RawMatrix M(NRegs.size() + 1, MRegs.size() + 1, 0);
bool NodesInterfere = false;
for (unsigned I = 0; I != NRegs.size(); ++I) {
unsigned PRegN = NRegs[I];
MCRegister PRegN = NRegs[I];
for (unsigned J = 0; J != MRegs.size(); ++J) {
unsigned PRegM = MRegs[J];
MCRegister PRegM = MRegs[J];
if (TRI.regsOverlap(PRegN, PRegM)) {
M[I + 1][J + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
NodesInterfere = true;
Expand Down Expand Up @@ -448,8 +448,8 @@ class Coalescing : public PBQPRAConstraint {
if (!CP.setRegisters(&MI) || CP.getSrcReg() == CP.getDstReg())
continue;

unsigned DstReg = CP.getDstReg();
unsigned SrcReg = CP.getSrcReg();
Register DstReg = CP.getDstReg();
Register SrcReg = CP.getSrcReg();

PBQP::PBQPNum CBenefit = MBFI.getBlockFreqRelativeToEntryBlock(&MBB);

Expand All @@ -463,7 +463,7 @@ class Coalescing : public PBQPRAConstraint {
G.getNodeMetadata(NId).getAllowedRegs();

unsigned PRegOpt = 0;
while (PRegOpt < Allowed.size() && Allowed[PRegOpt] != DstReg)
while (PRegOpt < Allowed.size() && Allowed[PRegOpt].id() != DstReg)
++PRegOpt;

if (PRegOpt < Allowed.size()) {
Expand Down Expand Up @@ -508,9 +508,9 @@ class Coalescing : public PBQPRAConstraint {
assert(CostMat.getRows() == Allowed1.size() + 1 && "Size mismatch.");
assert(CostMat.getCols() == Allowed2.size() + 1 && "Size mismatch.");
for (unsigned I = 0; I != Allowed1.size(); ++I) {
unsigned PReg1 = Allowed1[I];
MCRegister PReg1 = Allowed1[I];
for (unsigned J = 0; J != Allowed2.size(); ++J) {
unsigned PReg2 = Allowed2[J];
MCRegister PReg2 = Allowed2[J];
if (PReg1 == PReg2)
CostMat[I + 1][J + 1] -= Benefit;
}
Expand Down Expand Up @@ -571,18 +571,19 @@ void RegAllocPBQP::findVRegIntervalsToAlloc(const MachineFunction &MF,

// Iterate over all live ranges.
for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
unsigned Reg = Register::index2VirtReg(I);
Register Reg = Register::index2VirtReg(I);
if (MRI.reg_nodbg_empty(Reg))
continue;
VRegsToAlloc.insert(Reg);
}
}

static bool isACalleeSavedRegister(unsigned reg, const TargetRegisterInfo &TRI,
static bool isACalleeSavedRegister(MCRegister Reg,
const TargetRegisterInfo &TRI,
const MachineFunction &MF) {
const MCPhysReg *CSR = MF.getRegInfo().getCalleeSavedRegs();
for (unsigned i = 0; CSR[i] != 0; ++i)
if (TRI.regsOverlap(reg, CSR[i]))
if (TRI.regsOverlap(Reg, CSR[i]))
return true;
return false;
}
Expand All @@ -596,12 +597,12 @@ void RegAllocPBQP::initializeGraph(PBQPRAGraph &G, VirtRegMap &VRM,
const TargetRegisterInfo &TRI =
*G.getMetadata().MF.getSubtarget().getRegisterInfo();

std::vector<unsigned> Worklist(VRegsToAlloc.begin(), VRegsToAlloc.end());
std::vector<Register> Worklist(VRegsToAlloc.begin(), VRegsToAlloc.end());

std::map<unsigned, std::vector<unsigned>> VRegAllowedMap;
std::map<Register, std::vector<MCRegister>> VRegAllowedMap;

while (!Worklist.empty()) {
unsigned VReg = Worklist.back();
Register VReg = Worklist.back();
Worklist.pop_back();

LiveInterval &VRegLI = LIS.getInterval(VReg);
Expand All @@ -621,10 +622,10 @@ void RegAllocPBQP::initializeGraph(PBQPRAGraph &G, VirtRegMap &VRM,
LIS.checkRegMaskInterference(VRegLI, RegMaskOverlaps);

// Compute an initial allowed set for the current vreg.
std::vector<unsigned> VRegAllowed;
std::vector<MCRegister> VRegAllowed;
ArrayRef<MCPhysReg> RawPRegOrder = TRC->getRawAllocationOrder(MF);
for (unsigned I = 0; I != RawPRegOrder.size(); ++I) {
unsigned PReg = RawPRegOrder[I];
MCRegister PReg(RawPRegOrder[I]);
if (MRI.isReserved(PReg))
continue;

Expand Down Expand Up @@ -731,11 +732,11 @@ bool RegAllocPBQP::mapPBQPToRegAlloc(const PBQPRAGraph &G,
// Iterate over the nodes mapping the PBQP solution to a register
// assignment.
for (auto NId : G.nodeIds()) {
unsigned VReg = G.getNodeMetadata(NId).getVReg();
unsigned AllocOption = Solution.getSelection(NId);
Register VReg = G.getNodeMetadata(NId).getVReg();
unsigned AllocOpt = Solution.getSelection(NId);

if (AllocOption != PBQP::RegAlloc::getSpillOptionIdx()) {
unsigned PReg = G.getNodeMetadata(NId).getAllowedRegs()[AllocOption - 1];
if (AllocOpt != PBQP::RegAlloc::getSpillOptionIdx()) {
MCRegister PReg = G.getNodeMetadata(NId).getAllowedRegs()[AllocOpt - 1];
LLVM_DEBUG(dbgs() << "VREG " << printReg(VReg, &TRI) << " -> "
<< TRI.getName(PReg) << "\n");
assert(PReg != 0 && "Invalid preg selected.");
Expand Down Expand Up @@ -763,7 +764,7 @@ void RegAllocPBQP::finalizeAlloc(MachineFunction &MF,
I != E; ++I) {
LiveInterval &LI = LIS.getInterval(*I);

unsigned PReg = MRI.getSimpleHint(LI.reg());
Register PReg = MRI.getSimpleHint(LI.reg());

if (PReg == 0) {
const TargetRegisterClass &RC = *MRI.getRegClass(LI.reg());
Expand Down Expand Up @@ -884,7 +885,7 @@ static Printable PrintNodeInfo(PBQP::RegAlloc::PBQPRAGraph::NodeId NId,
return Printable([NId, &G](raw_ostream &OS) {
const MachineRegisterInfo &MRI = G.getMetadata().MF.getRegInfo();
const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
unsigned VReg = G.getNodeMetadata(NId).getVReg();
Register VReg = G.getNodeMetadata(NId).getVReg();
const char *RegClassName = TRI->getRegClassName(MRI.getRegClass(VReg));
OS << NId << " (" << RegClassName << ':' << printReg(VReg, TRI) << ')';
});
Expand Down

0 comments on commit 17cdba6

Please sign in to comment.