Skip to content

Commit

Permalink
[CodeGen][NewPM] Port machine-branch-prob to new pass manager (#96389)
Browse files Browse the repository at this point in the history
Like IR version `print<branch-prob>`, there is also a
`print<machine-branch-prob>`.
  • Loading branch information
paperchalice committed Jun 27, 2024
1 parent 73e6f9f commit 73e46c2
Show file tree
Hide file tree
Showing 24 changed files with 154 additions and 73 deletions.
53 changes: 43 additions & 10 deletions llvm/include/llvm/CodeGen/MachineBranchProbabilityInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@
#define LLVM_CODEGEN_MACHINEBRANCHPROBABILITYINFO_H

#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachinePassManager.h"
#include "llvm/Pass.h"
#include "llvm/Support/BranchProbability.h"

namespace llvm {

class MachineBranchProbabilityInfo : public ImmutablePass {
virtual void anchor();

class MachineBranchProbabilityInfo {
// Default weight value. Used when we don't have information about the edge.
// TODO: DEFAULT_WEIGHT makes sense during static predication, when none of
// the successors have a weight yet. But it doesn't make sense when providing
Expand All @@ -31,13 +30,8 @@ class MachineBranchProbabilityInfo : public ImmutablePass {
static const uint32_t DEFAULT_WEIGHT = 16;

public:
static char ID;

MachineBranchProbabilityInfo();

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesAll();
}
bool invalidate(MachineFunction &, const PreservedAnalyses &PA,
MachineFunctionAnalysisManager::Invalidator &);

// Return edge probability.
BranchProbability getEdgeProbability(const MachineBasicBlock *Src,
Expand All @@ -61,6 +55,45 @@ class MachineBranchProbabilityInfo : public ImmutablePass {
const MachineBasicBlock *Dst) const;
};

class MachineBranchProbabilityAnalysis
: public AnalysisInfoMixin<MachineBranchProbabilityAnalysis> {
friend AnalysisInfoMixin<MachineBranchProbabilityAnalysis>;

static AnalysisKey Key;

public:
using Result = MachineBranchProbabilityInfo;

Result run(MachineFunction &, MachineFunctionAnalysisManager &);
};

class MachineBranchProbabilityPrinterPass
: public PassInfoMixin<MachineBranchProbabilityPrinterPass> {
raw_ostream &OS;

public:
MachineBranchProbabilityPrinterPass(raw_ostream &OS) : OS(OS) {}
PreservedAnalyses run(MachineFunction &MF,
MachineFunctionAnalysisManager &MFAM);
};

class MachineBranchProbabilityInfoWrapperPass : public ImmutablePass {
virtual void anchor();

MachineBranchProbabilityInfo MBPI;

public:
static char ID;

MachineBranchProbabilityInfoWrapperPass();

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesAll();
}

MachineBranchProbabilityInfo &getMBPI() { return MBPI; }
const MachineBranchProbabilityInfo &getMBPI() const { return MBPI; }
};
}


Expand Down
2 changes: 1 addition & 1 deletion llvm/include/llvm/InitializePasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ void initializeMIRPrintingPassPass(PassRegistry&);
void initializeMachineBlockFrequencyInfoPass(PassRegistry&);
void initializeMachineBlockPlacementPass(PassRegistry&);
void initializeMachineBlockPlacementStatsPass(PassRegistry&);
void initializeMachineBranchProbabilityInfoPass(PassRegistry&);
void initializeMachineBranchProbabilityInfoWrapperPassPass(PassRegistry &);
void initializeMachineCFGPrinterPass(PassRegistry &);
void initializeMachineCSEPass(PassRegistry&);
void initializeMachineCombinerPass(PassRegistry&);
Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/Passes/MachinePassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ LOOP_PASS("loop-reduce", LoopStrengthReducePass())
#ifndef MACHINE_FUNCTION_ANALYSIS
#define MACHINE_FUNCTION_ANALYSIS(NAME, CREATE_PASS)
#endif
MACHINE_FUNCTION_ANALYSIS("machine-branch-prob",
MachineBranchProbabilityAnalysis())
MACHINE_FUNCTION_ANALYSIS("machine-dom-tree", MachineDominatorTreeAnalysis())
MACHINE_FUNCTION_ANALYSIS("machine-post-dom-tree",
MachinePostDominatorTreeAnalysis())
Expand Down Expand Up @@ -130,6 +132,8 @@ MACHINE_FUNCTION_PASS("finalize-isel", FinalizeISelPass())
MACHINE_FUNCTION_PASS("localstackalloc", LocalStackSlotAllocationPass())
MACHINE_FUNCTION_PASS("no-op-machine-function", NoOpMachineFunctionPass())
MACHINE_FUNCTION_PASS("print", PrintMIRPass())
MACHINE_FUNCTION_PASS("print<machine-branch-prob>",
MachineBranchProbabilityPrinterPass(dbgs()))
MACHINE_FUNCTION_PASS("print<machine-dom-tree>",
MachineDominatorTreePrinterPass(dbgs()))
MACHINE_FUNCTION_PASS("print<machine-post-dom-tree>",
Expand Down
7 changes: 4 additions & 3 deletions llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ void AsmPrinter::getAnalysisUsage(AnalysisUsage &AU) const {
AU.addRequired<MachineOptimizationRemarkEmitterPass>();
AU.addRequired<GCModuleInfo>();
AU.addRequired<LazyMachineBlockFrequencyInfoPass>();
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
}

bool AsmPrinter::doInitialization(Module &M) {
Expand Down Expand Up @@ -1478,8 +1478,9 @@ void AsmPrinter::emitBBAddrMapSection(const MachineFunction &MF) {
? &getAnalysis<LazyMachineBlockFrequencyInfoPass>().getBFI()
: nullptr;
const MachineBranchProbabilityInfo *MBPI =
Features.BrProb ? &getAnalysis<MachineBranchProbabilityInfo>()
: nullptr;
Features.BrProb
? &getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI()
: nullptr;

if (Features.BBFreq || Features.BrProb) {
for (const MachineBasicBlock &MBB : MF) {
Expand Down
9 changes: 5 additions & 4 deletions llvm/lib/CodeGen/BranchFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ namespace {

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<MachineBlockFrequencyInfo>();
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
AU.addRequired<ProfileSummaryInfoWrapperPass>();
AU.addRequired<TargetPassConfig>();
MachineFunctionPass::getAnalysisUsage(AU);
Expand Down Expand Up @@ -131,9 +131,10 @@ bool BranchFolderPass::runOnMachineFunction(MachineFunction &MF) {
PassConfig->getEnableTailMerge();
MBFIWrapper MBBFreqInfo(
getAnalysis<MachineBlockFrequencyInfo>());
BranchFolder Folder(EnableTailMerge, /*CommonHoist=*/true, MBBFreqInfo,
getAnalysis<MachineBranchProbabilityInfo>(),
&getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI());
BranchFolder Folder(
EnableTailMerge, /*CommonHoist=*/true, MBBFreqInfo,
getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI(),
&getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI());
return Folder.OptimizeFunction(MF, MF.getSubtarget().getInstrInfo(),
MF.getSubtarget().getRegisterInfo());
}
Expand Down
10 changes: 5 additions & 5 deletions llvm/lib/CodeGen/EarlyIfConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -787,14 +787,14 @@ char &llvm::EarlyIfConverterID = EarlyIfConverter::ID;

INITIALIZE_PASS_BEGIN(EarlyIfConverter, DEBUG_TYPE,
"Early If Converter", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfo)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineTraceMetrics)
INITIALIZE_PASS_END(EarlyIfConverter, DEBUG_TYPE,
"Early If Converter", false, false)

void EarlyIfConverter::getAnalysisUsage(AnalysisUsage &AU) const {
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
AU.addRequired<MachineDominatorTreeWrapperPass>();
AU.addPreserved<MachineDominatorTreeWrapperPass>();
AU.addRequired<MachineLoopInfo>();
Expand Down Expand Up @@ -1142,12 +1142,12 @@ char &llvm::EarlyIfPredicatorID = EarlyIfPredicator::ID;
INITIALIZE_PASS_BEGIN(EarlyIfPredicator, DEBUG_TYPE, "Early If Predicator",
false, false)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfo)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_END(EarlyIfPredicator, DEBUG_TYPE, "Early If Predicator", false,
false)

void EarlyIfPredicator::getAnalysisUsage(AnalysisUsage &AU) const {
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
AU.addRequired<MachineDominatorTreeWrapperPass>();
AU.addPreserved<MachineDominatorTreeWrapperPass>();
AU.addRequired<MachineLoopInfo>();
Expand Down Expand Up @@ -1222,7 +1222,7 @@ bool EarlyIfPredicator::runOnMachineFunction(MachineFunction &MF) {
SchedModel.init(&STI);
DomTree = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
Loops = &getAnalysis<MachineLoopInfo>();
MBPI = &getAnalysis<MachineBranchProbabilityInfo>();
MBPI = &getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI();

bool Changed = false;
IfConv.runOnMachineFunction(MF);
Expand Down
10 changes: 6 additions & 4 deletions llvm/lib/CodeGen/GlobalISel/RegBankSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ INITIALIZE_PASS_BEGIN(RegBankSelect, DEBUG_TYPE,
"Assign register bank of generic virtual registers",
false, false);
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfo)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(RegBankSelect, DEBUG_TYPE,
"Assign register bank of generic virtual registers", false,
Expand All @@ -86,7 +86,7 @@ void RegBankSelect::init(MachineFunction &MF) {
TPC = &getAnalysis<TargetPassConfig>();
if (OptMode != Mode::Fast) {
MBFI = &getAnalysis<MachineBlockFrequencyInfo>();
MBPI = &getAnalysis<MachineBranchProbabilityInfo>();
MBPI = &getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI();
} else {
MBFI = nullptr;
MBPI = nullptr;
Expand All @@ -100,7 +100,7 @@ void RegBankSelect::getAnalysisUsage(AnalysisUsage &AU) const {
// We could preserve the information from these two analysis but
// the APIs do not allow to do so yet.
AU.addRequired<MachineBlockFrequencyInfo>();
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
}
AU.addRequired<TargetPassConfig>();
getSelectionDAGFallbackAnalysisUsage(AU);
Expand Down Expand Up @@ -955,8 +955,10 @@ uint64_t RegBankSelect::EdgeInsertPoint::frequency(const Pass &P) const {
if (WasMaterialized)
return MBFI->getBlockFreq(DstOrSplit).getFrequency();

auto *MBPIWrapper =
P.getAnalysisIfAvailable<MachineBranchProbabilityInfoWrapperPass>();
const MachineBranchProbabilityInfo *MBPI =
P.getAnalysisIfAvailable<MachineBranchProbabilityInfo>();
MBPIWrapper ? &MBPIWrapper->getMBPI() : nullptr;
if (!MBPI)
return 1;
// The basic block will be on the edge.
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/CodeGen/IfConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ namespace {

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<MachineBlockFrequencyInfo>();
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
AU.addRequired<ProfileSummaryInfoWrapperPass>();
MachineFunctionPass::getAnalysisUsage(AU);
}
Expand Down Expand Up @@ -432,7 +432,7 @@ char IfConverter::ID = 0;
char &llvm::IfConverterID = IfConverter::ID;

INITIALIZE_PASS_BEGIN(IfConverter, DEBUG_TYPE, "If Converter", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfo)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
INITIALIZE_PASS_END(IfConverter, DEBUG_TYPE, "If Converter", false, false)

Expand All @@ -445,7 +445,7 @@ bool IfConverter::runOnMachineFunction(MachineFunction &MF) {
TII = ST.getInstrInfo();
TRI = ST.getRegisterInfo();
MBFIWrapper MBFI(getAnalysis<MachineBlockFrequencyInfo>());
MBPI = &getAnalysis<MachineBranchProbabilityInfo>();
MBPI = &getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI();
ProfileSummaryInfo *PSI =
&getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
MRI = &MF.getRegInfo();
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/CodeGen/LazyMachineBlockFrequencyInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ using namespace llvm;

INITIALIZE_PASS_BEGIN(LazyMachineBlockFrequencyInfoPass, DEBUG_TYPE,
"Lazy Machine Block Frequency Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfo)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_END(LazyMachineBlockFrequencyInfoPass, DEBUG_TYPE,
"Lazy Machine Block Frequency Analysis", true, true)
Expand All @@ -43,7 +43,7 @@ void LazyMachineBlockFrequencyInfoPass::print(raw_ostream &OS,

void LazyMachineBlockFrequencyInfoPass::getAnalysisUsage(
AnalysisUsage &AU) const {
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
AU.setPreservesAll();
MachineFunctionPass::getAnalysisUsage(AU);
}
Expand All @@ -62,7 +62,7 @@ LazyMachineBlockFrequencyInfoPass::calculateIfNotAvailable() const {
return *MBFI;
}

auto &MBPI = getAnalysis<MachineBranchProbabilityInfo>();
auto &MBPI = getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI();
auto *MLI = getAnalysisIfAvailable<MachineLoopInfo>();
auto *MDTWrapper = getAnalysisIfAvailable<MachineDominatorTreeWrapperPass>();
auto *MDT = MDTWrapper ? &MDTWrapper->getDomTree() : nullptr;
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/CodeGen/MachineBlockFrequencyInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ struct DOTGraphTraits<MachineBlockFrequencyInfo *>

INITIALIZE_PASS_BEGIN(MachineBlockFrequencyInfo, DEBUG_TYPE,
"Machine Block Frequency Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfo)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_END(MachineBlockFrequencyInfo, DEBUG_TYPE,
"Machine Block Frequency Analysis", true, true)
Expand All @@ -185,7 +185,7 @@ MachineBlockFrequencyInfo::MachineBlockFrequencyInfo(
MachineBlockFrequencyInfo::~MachineBlockFrequencyInfo() = default;

void MachineBlockFrequencyInfo::getAnalysisUsage(AnalysisUsage &AU) const {
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
AU.addRequired<MachineLoopInfo>();
AU.setPreservesAll();
MachineFunctionPass::getAnalysisUsage(AU);
Expand All @@ -209,7 +209,7 @@ void MachineBlockFrequencyInfo::calculate(

bool MachineBlockFrequencyInfo::runOnMachineFunction(MachineFunction &F) {
MachineBranchProbabilityInfo &MBPI =
getAnalysis<MachineBranchProbabilityInfo>();
getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI();
MachineLoopInfo &MLI = getAnalysis<MachineLoopInfo>();
calculate(F, MBPI, MLI);
return false;
Expand Down
12 changes: 6 additions & 6 deletions llvm/lib/CodeGen/MachineBlockPlacement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ class MachineBlockPlacement : public MachineFunctionPass {
}

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
AU.addRequired<MachineBlockFrequencyInfo>();
if (TailDupPlacement)
AU.addRequired<MachinePostDominatorTreeWrapperPass>();
Expand All @@ -627,7 +627,7 @@ char &llvm::MachineBlockPlacementID = MachineBlockPlacement::ID;

INITIALIZE_PASS_BEGIN(MachineBlockPlacement, DEBUG_TYPE,
"Branch Probability Basic Block Placement", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfo)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
Expand Down Expand Up @@ -3425,7 +3425,7 @@ bool MachineBlockPlacement::runOnMachineFunction(MachineFunction &MF) {
return false;

F = &MF;
MBPI = &getAnalysis<MachineBranchProbabilityInfo>();
MBPI = &getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI();
MBFI = std::make_unique<MBFIWrapper>(
getAnalysis<MachineBlockFrequencyInfo>());
MLI = &getAnalysis<MachineLoopInfo>();
Expand Down Expand Up @@ -3726,7 +3726,7 @@ class MachineBlockPlacementStats : public MachineFunctionPass {
bool runOnMachineFunction(MachineFunction &F) override;

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
AU.addRequired<MachineBlockFrequencyInfo>();
AU.setPreservesAll();
MachineFunctionPass::getAnalysisUsage(AU);
Expand All @@ -3741,7 +3741,7 @@ char &llvm::MachineBlockPlacementStatsID = MachineBlockPlacementStats::ID;

INITIALIZE_PASS_BEGIN(MachineBlockPlacementStats, "block-placement-stats",
"Basic Block Placement Stats", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfo)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo)
INITIALIZE_PASS_END(MachineBlockPlacementStats, "block-placement-stats",
"Basic Block Placement Stats", false, false)
Expand All @@ -3754,7 +3754,7 @@ bool MachineBlockPlacementStats::runOnMachineFunction(MachineFunction &F) {
if (!isFunctionInPrintList(F.getName()))
return false;

MBPI = &getAnalysis<MachineBranchProbabilityInfo>();
MBPI = &getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI();
MBFI = &getAnalysis<MachineBlockFrequencyInfo>();

for (MachineBasicBlock &MBB : F) {
Expand Down
Loading

0 comments on commit 73e46c2

Please sign in to comment.