Skip to content

Commit

Permalink
[AMDGPU] GCNRegPressurePrinter pass to print GCNRegPressure values fo…
Browse files Browse the repository at this point in the history
…r testing. (#70031)

Using GCNDownwardRPTracker or GCNUpwardRPTracker the pass collects register pressure values for a function and prints these values next to instructions. Output can be used to generate Filecheck rules in mir tests.
  • Loading branch information
vpykhtin committed Nov 1, 2023
1 parent e69e066 commit e808f8a
Show file tree
Hide file tree
Showing 5 changed files with 625 additions and 19 deletions.
3 changes: 3 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPU.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ extern char &AMDGPUImageIntrinsicOptimizerID;
void initializeAMDGPUPerfHintAnalysisPass(PassRegistry &);
extern char &AMDGPUPerfHintAnalysisID;

void initializeGCNRegPressurePrinterPass(PassRegistry &);
extern char &GCNRegPressurePrinterID;

// Passes common to R600 and SI
FunctionPass *createAMDGPUPromoteAlloca();
void initializeAMDGPUPromoteAllocaPass(PassRegistry&);
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeAMDGPUTarget() {
initializeGCNPreRAOptimizationsPass(*PR);
initializeGCNPreRALongBranchRegPass(*PR);
initializeGCNRewritePartialRegUsesPass(*PR);
initializeGCNRegPressurePrinterPass(*PR);
}

static std::unique_ptr<TargetLoweringObjectFile> createTLOF(const Triple &TT) {
Expand Down
144 changes: 130 additions & 14 deletions llvm/lib/Target/AMDGPU/GCNRegPressure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//

#include "GCNRegPressure.h"
#include "AMDGPU.h"
#include "llvm/CodeGen/RegisterPressure.h"

using namespace llvm;
Expand All @@ -31,7 +32,6 @@ bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
return true;
}


///////////////////////////////////////////////////////////////////////////////
// GCNRegPressure

Expand Down Expand Up @@ -135,8 +135,6 @@ bool GCNRegPressure::less(const GCNSubtarget &ST,
O.getVGPRNum(ST.hasGFX90AInsts()));
}

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
LLVM_DUMP_METHOD
Printable llvm::print(const GCNRegPressure &RP, const GCNSubtarget *ST) {
return Printable([&RP, ST](raw_ostream &OS) {
OS << "VGPRs: " << RP.Value[GCNRegPressure::VGPR32] << ' '
Expand All @@ -155,7 +153,6 @@ Printable llvm::print(const GCNRegPressure &RP, const GCNSubtarget *ST) {
OS << '\n';
});
}
#endif

static LaneBitmask getDefRegMask(const MachineOperand &MO,
const MachineRegisterInfo &MRI) {
Expand Down Expand Up @@ -269,6 +266,13 @@ void GCNUpwardRPTracker::reset(const MachineInstr &MI,
GCNRPTracker::reset(MI, LiveRegsCopy, true);
}

void GCNUpwardRPTracker::reset(const MachineRegisterInfo &MRI_,
const LiveRegSet &LiveRegs_) {
MRI = &MRI_;
LiveRegs = LiveRegs_;
MaxPressure = CurPressure = getRegPressure(MRI_, LiveRegs_);
}

void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
assert(MRI && "call reset first");

Expand Down Expand Up @@ -418,27 +422,25 @@ bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
return advance(End);
}

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
LLVM_DUMP_METHOD
Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
const GCNRPTracker::LiveRegSet &TrackedLR,
const TargetRegisterInfo *TRI) {
return Printable([&LISLR, &TrackedLR, TRI](raw_ostream &OS) {
const TargetRegisterInfo *TRI, StringRef Pfx) {
return Printable([&LISLR, &TrackedLR, TRI, Pfx](raw_ostream &OS) {
for (auto const &P : TrackedLR) {
auto I = LISLR.find(P.first);
if (I == LISLR.end()) {
OS << " " << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second)
OS << Pfx << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second)
<< " isn't found in LIS reported set\n";
} else if (I->second != P.second) {
OS << " " << printReg(P.first, TRI)
OS << Pfx << printReg(P.first, TRI)
<< " masks doesn't match: LIS reported " << PrintLaneMask(I->second)
<< ", tracked " << PrintLaneMask(P.second) << '\n';
}
}
for (auto const &P : LISLR) {
auto I = TrackedLR.find(P.first);
if (I == TrackedLR.end()) {
OS << " " << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second)
OS << Pfx << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second)
<< " isn't found in tracked set\n";
}
}
Expand Down Expand Up @@ -467,7 +469,6 @@ bool GCNUpwardRPTracker::isValid() const {
return true;
}

LLVM_DUMP_METHOD
Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs,
const MachineRegisterInfo &MRI) {
return Printable([&LiveRegs, &MRI](raw_ostream &OS) {
Expand All @@ -483,7 +484,122 @@ Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs,
});
}

LLVM_DUMP_METHOD
void GCNRegPressure::dump() const { dbgs() << print(*this); }

#endif
static cl::opt<bool> UseDownwardTracker(
"amdgpu-print-rp-downward",
cl::desc("Use GCNDownwardRPTracker for GCNRegPressurePrinter pass"),
cl::init(false), cl::Hidden);

char llvm::GCNRegPressurePrinter::ID = 0;
char &llvm::GCNRegPressurePrinterID = GCNRegPressurePrinter::ID;

INITIALIZE_PASS(GCNRegPressurePrinter, "amdgpu-print-rp", "", true, true)

bool GCNRegPressurePrinter::runOnMachineFunction(MachineFunction &MF) {
const MachineRegisterInfo &MRI = MF.getRegInfo();
const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
const LiveIntervals &LIS = getAnalysis<LiveIntervals>();

auto &OS = dbgs();

// Leading spaces are important for YAML syntax.
#define PFX " "

OS << "---\nname: " << MF.getName() << "\nbody: |\n";

auto printRP = [](const GCNRegPressure &RP) {
return Printable([&RP](raw_ostream &OS) {
OS << format(PFX " %-5d", RP.getSGPRNum())
<< format(" %-5d", RP.getVGPRNum(false));
});
};

auto ReportLISMismatchIfAny = [&](const GCNRPTracker::LiveRegSet &TrackedLR,
const GCNRPTracker::LiveRegSet &LISLR) {
if (LISLR != TrackedLR) {
OS << PFX " mis LIS: " << llvm::print(LISLR, MRI)
<< reportMismatch(LISLR, TrackedLR, TRI, PFX " ");
}
};

// Register pressure before and at an instruction (in program order).
SmallVector<std::pair<GCNRegPressure, GCNRegPressure>, 16> RP;

for (auto &MBB : MF) {
RP.clear();
RP.reserve(MBB.size());

OS << PFX;
MBB.printName(OS);
OS << ":\n";

SlotIndex MBBStartSlot = LIS.getSlotIndexes()->getMBBStartIdx(&MBB);
SlotIndex MBBEndSlot = LIS.getSlotIndexes()->getMBBEndIdx(&MBB);

GCNRPTracker::LiveRegSet LiveIn, LiveOut;
GCNRegPressure RPAtMBBEnd;

if (UseDownwardTracker) {
if (MBB.empty()) {
LiveIn = LiveOut = getLiveRegs(MBBStartSlot, LIS, MRI);
RPAtMBBEnd = getRegPressure(MRI, LiveIn);
} else {
GCNDownwardRPTracker RPT(LIS);
RPT.reset(MBB.front());

LiveIn = RPT.getLiveRegs();

while (!RPT.advanceBeforeNext()) {
GCNRegPressure RPBeforeMI = RPT.getPressure();
RPT.advanceToNext();
RP.emplace_back(RPBeforeMI, RPT.getPressure());
}

LiveOut = RPT.getLiveRegs();
RPAtMBBEnd = RPT.getPressure();
}
} else {
GCNUpwardRPTracker RPT(LIS);
RPT.reset(MRI, MBBEndSlot);
RPT.moveMaxPressure(); // Clear max pressure.

LiveOut = RPT.getLiveRegs();
RPAtMBBEnd = RPT.getPressure();

for (auto &MI : reverse(MBB)) {
RPT.recede(MI);
if (!MI.isDebugInstr())
RP.emplace_back(RPT.getPressure(), RPT.moveMaxPressure());
}

LiveIn = RPT.getLiveRegs();
}

OS << PFX " Live-in: " << llvm::print(LiveIn, MRI);
if (!UseDownwardTracker)
ReportLISMismatchIfAny(LiveIn, getLiveRegs(MBBStartSlot, LIS, MRI));

OS << PFX " SGPR VGPR\n";
int I = 0;
for (auto &MI : MBB) {
if (!MI.isDebugInstr()) {
auto &[RPBeforeInstr, RPAtInstr] =
RP[UseDownwardTracker ? I : (RP.size() - 1 - I)];
++I;
OS << printRP(RPBeforeInstr) << '\n' << printRP(RPAtInstr) << " ";
} else
OS << PFX " ";
MI.print(OS);
}
OS << printRP(RPAtMBBEnd) << '\n';

OS << PFX " Live-out:" << llvm::print(LiveOut, MRI);
if (UseDownwardTracker)
ReportLISMismatchIfAny(LiveOut, getLiveRegs(MBBEndSlot, LIS, MRI));
}
OS << "...\n";
return false;

#undef PFX
}
34 changes: 29 additions & 5 deletions llvm/lib/Target/AMDGPU/GCNRegPressure.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ class GCNRPTracker {

void clearMaxPressure() { MaxPressure.clear(); }

GCNRegPressure getPressure() const { return CurPressure; }

// returns MaxPressure, resetting it
decltype(MaxPressure) moveMaxPressure() {
auto Res = MaxPressure;
Expand All @@ -140,6 +142,9 @@ class GCNRPTracker {
}
};

GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI, const LiveIntervals &LIS,
const MachineRegisterInfo &MRI);

class GCNUpwardRPTracker : public GCNRPTracker {
public:
GCNUpwardRPTracker(const LiveIntervals &LIS_) : GCNRPTracker(LIS_) {}
Expand All @@ -148,6 +153,14 @@ class GCNUpwardRPTracker : public GCNRPTracker {
// filling live regs upon this point using LIS
void reset(const MachineInstr &MI, const LiveRegSet *LiveRegs = nullptr);

// reset tracker and set live register set to the specified value.
void reset(const MachineRegisterInfo &MRI_, const LiveRegSet &LiveRegs_);

// reset tracker at the specified slot index.
void reset(const MachineRegisterInfo &MRI_, SlotIndex SI) {
reset(MRI_, llvm::getLiveRegs(SI, LIS, MRI_));
}

// move to the state just above the MI
void recede(const MachineInstr &MI);

Expand Down Expand Up @@ -196,10 +209,6 @@ LaneBitmask getLiveLaneMask(unsigned Reg,
const LiveIntervals &LIS,
const MachineRegisterInfo &MRI);

GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI,
const LiveIntervals &LIS,
const MachineRegisterInfo &MRI);

/// creates a map MachineInstr -> LiveRegSet
/// R - range of iterators on instructions
/// After - upon entry or exit of every instruction
Expand Down Expand Up @@ -275,7 +284,22 @@ Printable print(const GCNRPTracker::LiveRegSet &LiveRegs,

Printable reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
const GCNRPTracker::LiveRegSet &TrackedL,
const TargetRegisterInfo *TRI);
const TargetRegisterInfo *TRI, StringRef Pfx = " ");

struct GCNRegPressurePrinter : public MachineFunctionPass {
static char ID;

public:
GCNRegPressurePrinter() : MachineFunctionPass(ID) {}

bool runOnMachineFunction(MachineFunction &MF) override;

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<LiveIntervals>();
AU.setPreservesAll();
MachineFunctionPass::getAnalysisUsage(AU);
}
};

} // end namespace llvm

Expand Down

0 comments on commit e808f8a

Please sign in to comment.