Skip to content

Commit

Permalink
[AArch64] Relax cross-section branches
Browse files Browse the repository at this point in the history
Because code layout is not known during compilation, the distance of a
cross-section jump cannot be determined at compile time. We should
therefore assume that every cross-section jump is out of range. This
assumption is necessary for machine function splitting on AArch64, which
introduces cross-section branches in the middle of functions. The linker
relaxes out-of-range unconditional branches, but it clobbers X16 to do
so; it does not relax conditional branches, so the compiler must relax
those itself.

Differential Revision: https://reviews.llvm.org/D145211
  • Loading branch information
dhoekwater committed Aug 16, 2023
1 parent 8e44f03 commit d7bca8e
Show file tree
Hide file tree
Showing 7 changed files with 338 additions and 7 deletions.
3 changes: 3 additions & 0 deletions llvm/include/llvm/Target/TargetMachine.h
Expand Up @@ -232,6 +232,9 @@ class TargetMachine {
/// target default.
CodeModel::Model getCodeModel() const { return CMModel; }

/// Returns the maximum code size possible under the current code model
/// (see getCodeModel()). The bound depends only on the code model: the
/// smallest value for Tiny, a shared larger value for Small, Kernel and
/// Medium, and the full unsigned 64-bit range for Large.
uint64_t getMaxCodeSize() const;

/// Set the code model.
void setCodeModel(CodeModel::Model CM) { CMModel = CM; }

Expand Down
45 changes: 39 additions & 6 deletions llvm/lib/CodeGen/BranchRelaxation.cpp
Expand Up @@ -26,6 +26,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include <cassert>
#include <cstdint>
#include <iterator>
Expand Down Expand Up @@ -84,6 +85,7 @@ class BranchRelaxation : public MachineFunctionPass {
MachineFunction *MF = nullptr;
const TargetRegisterInfo *TRI = nullptr;
const TargetInstrInfo *TII = nullptr;
const TargetMachine *TM = nullptr;

bool relaxBranchInstructions();
void scanFunction();
Expand Down Expand Up @@ -232,6 +234,11 @@ BranchRelaxation::createNewBlockAfter(MachineBasicBlock &OrigMBB,
MachineBasicBlock *NewBB = MF->CreateMachineBasicBlock(BB);
MF->insert(++OrigMBB.getIterator(), NewBB);

// Place the new block in the same section as OrigMBB.
NewBB->setSectionID(OrigMBB.getSectionID());
NewBB->setIsEndSection(OrigMBB.isEndSection());
OrigMBB.setIsEndSection(false);

// Insert an entry into BlockInfo to align it properly with the block numbers.
BlockInfo.insert(BlockInfo.begin() + NewBB->getNumber(), BasicBlockInfo());

Expand All @@ -241,15 +248,21 @@ BranchRelaxation::createNewBlockAfter(MachineBasicBlock &OrigMBB,
/// Split the basic block containing MI into two blocks, which are joined by
/// an unconditional branch. Update data structures and renumber blocks to
/// account for this change and returns the newly created block.
MachineBasicBlock *BranchRelaxation::splitBlockBeforeInstr(MachineInstr &MI,
MachineBasicBlock *DestBB) {
MachineBasicBlock *
BranchRelaxation::splitBlockBeforeInstr(MachineInstr &MI,
MachineBasicBlock *DestBB) {
MachineBasicBlock *OrigBB = MI.getParent();

// Create a new MBB for the code after the OrigBB.
MachineBasicBlock *NewBB =
MF->CreateMachineBasicBlock(OrigBB->getBasicBlock());
MF->insert(++OrigBB->getIterator(), NewBB);

// Place the new block in the same section as OrigBB.
NewBB->setSectionID(OrigBB->getSectionID());
NewBB->setIsEndSection(OrigBB->isEndSection());
OrigBB->setIsEndSection(false);

// Splice the instructions starting with MI over to NewBB.
NewBB->splice(NewBB->end(), OrigBB, MI.getIterator(), OrigBB->end());

Expand Down Expand Up @@ -300,7 +313,12 @@ bool BranchRelaxation::isBlockInRange(
int64_t BrOffset = getInstrOffset(MI);
int64_t DestOffset = BlockInfo[DestBB.getNumber()].Offset;

if (TII->isBranchOffsetInRange(MI.getOpcode(), DestOffset - BrOffset))
const MachineBasicBlock *SrcBB = MI.getParent();

if (TII->isBranchOffsetInRange(MI.getOpcode(),
SrcBB->getSectionID() != DestBB.getSectionID()
? TM->getMaxCodeSize()
: DestOffset - BrOffset))
return true;

LLVM_DEBUG(dbgs() << "Out of range branch to destination "
Expand Down Expand Up @@ -462,7 +480,10 @@ bool BranchRelaxation::fixupUnconditionalBranch(MachineInstr &MI) {
int64_t DestOffset = BlockInfo[DestBB->getNumber()].Offset;
int64_t SrcOffset = getInstrOffset(MI);

assert(!TII->isBranchOffsetInRange(MI.getOpcode(), DestOffset - SrcOffset));
assert(!TII->isBranchOffsetInRange(
MI.getOpcode(), MBB->getSectionID() != DestBB->getSectionID()
? TM->getMaxCodeSize()
: DestOffset - SrcOffset));

BlockInfo[MBB->getNumber()].Size -= OldBrSize;

Expand Down Expand Up @@ -492,9 +513,15 @@ bool BranchRelaxation::fixupUnconditionalBranch(MachineInstr &MI) {
// be erased.
MachineBasicBlock *RestoreBB = createNewBlockAfter(MF->back(),
DestBB->getBasicBlock());
std::prev(RestoreBB->getIterator())
->setIsEndSection(RestoreBB->isEndSection());
RestoreBB->setIsEndSection(false);

TII->insertIndirectBranch(*BranchBB, *DestBB, *RestoreBB, DL,
DestOffset - SrcOffset, RS.get());
BranchBB->getSectionID() != DestBB->getSectionID()
? TM->getMaxCodeSize()
: DestOffset - SrcOffset,
RS.get());

BlockInfo[BranchBB->getNumber()].Size = computeBlockSize(*BranchBB);
adjustBlockOffsets(*MBB);
Expand Down Expand Up @@ -525,6 +552,11 @@ bool BranchRelaxation::fixupUnconditionalBranch(MachineInstr &MI) {
BlockInfo[RestoreBB->getNumber()].Size = computeBlockSize(*RestoreBB);
// Update the offset starting from the previous block.
adjustBlockOffsets(*PrevBB);

// Fix up section information for RestoreBB and DestBB
RestoreBB->setSectionID(DestBB->getSectionID());
RestoreBB->setIsBeginSection(DestBB->isBeginSection());
DestBB->setIsBeginSection(false);
} else {
// Remove restore block if it's not required.
MF->erase(RestoreBB);
Expand Down Expand Up @@ -553,7 +585,7 @@ bool BranchRelaxation::relaxBranchInstructions() {
// Unconditional branch destination might be unanalyzable, assume these
// are OK.
if (MachineBasicBlock *DestBB = TII->getBranchDestBlock(*Last)) {
if (!isBlockInRange(*Last, *DestBB)) {
if (!isBlockInRange(*Last, *DestBB) && !TII->isTailCall(*Last)) {
fixupUnconditionalBranch(*Last);
++NumUnconditionalRelaxed;
Changed = true;
Expand Down Expand Up @@ -607,6 +639,7 @@ bool BranchRelaxation::runOnMachineFunction(MachineFunction &mf) {

const TargetSubtargetInfo &ST = MF->getSubtarget();
TII = ST.getInstrInfo();
TM = &MF->getTarget();

TRI = ST.getRegisterInfo();
if (TRI->trackLivenessAfterRegAlloc(*MF))
Expand Down
69 changes: 68 additions & 1 deletion llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
Expand Up @@ -30,6 +30,7 @@
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/RegisterScavenging.h"
#include "llvm/CodeGen/StackMaps.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
Expand Down Expand Up @@ -70,6 +71,10 @@ static cl::opt<unsigned>
BCCDisplacementBits("aarch64-bcc-offset-bits", cl::Hidden, cl::init(19),
cl::desc("Restrict range of Bcc instructions (DEBUG)"));

// Hidden debugging option: the number of displacement bits assumed for the
// unconditional B instruction (26 unless overridden on the command line).
// getBranchDisplacementBits() returns this value for AArch64::B, so lowering
// it restricts the range that B is considered able to reach.
static cl::opt<unsigned>
BDisplacementBits("aarch64-b-offset-bits", cl::Hidden, cl::init(26),
cl::desc("Restrict range of B instructions (DEBUG)"));

AArch64InstrInfo::AArch64InstrInfo(const AArch64Subtarget &STI)
: AArch64GenInstrInfo(AArch64::ADJCALLSTACKDOWN, AArch64::ADJCALLSTACKUP,
AArch64::CATCHRET),
Expand Down Expand Up @@ -203,7 +208,7 @@ static unsigned getBranchDisplacementBits(unsigned Opc) {
default:
llvm_unreachable("unexpected opcode!");
case AArch64::B:
return 64;
return BDisplacementBits;
case AArch64::TBNZW:
case AArch64::TBZW:
case AArch64::TBNZX:
Expand Down Expand Up @@ -248,6 +253,68 @@ AArch64InstrInfo::getBranchDestBlock(const MachineInstr &MI) const {
}
}

/// Fill the empty block \p MBB with an indirect branch that replaces an
/// out-of-range unconditional branch.
///
/// The branch is an ADRP + ADD + BR sequence through a 64-bit GPR. When the
/// register scavenger \p RS finds a free register at the end of \p MBB, the
/// sequence jumps straight to \p NewDestBB. Otherwise X16 is pushed onto the
/// stack, the sequence jumps to \p RestoreBB, and \p RestoreBB is given the
/// matching pop of X16.
///
/// \param MBB        Empty block, with exactly one predecessor, that receives
///                   the branch sequence.
/// \param NewDestBB  Destination used when a free register is available.
/// \param RestoreBB  Empty block that receives the X16 reload when a spill is
///                   needed; left untouched otherwise.
/// \param DL         Debug location attached to every emitted instruction.
/// \param BrOffset   Branch distance; a fatal error is reported if it does
///                   not fit in a signed 33-bit integer.
/// \param RS         Register scavenger; must not be null.
void AArch64InstrInfo::insertIndirectBranch(MachineBasicBlock &MBB,
                                            MachineBasicBlock &NewDestBB,
                                            MachineBasicBlock &RestoreBB,
                                            const DebugLoc &DL,
                                            int64_t BrOffset,
                                            RegScavenger *RS) const {
  assert(RS && "RegScavenger required for long branching");
  assert(MBB.empty() &&
         "new block should be inserted for expanding unconditional branch");
  assert(MBB.pred_size() == 1);
  assert(RestoreBB.empty() &&
         "restore block should be inserted for restoring clobbered registers");

  // Materialise the address of Target in AddrReg and branch through it. All
  // three instructions are appended to MBB.
  auto emitAddressAndJump = [&](Register AddrReg, MachineBasicBlock &Target) {
    // ADRP + ADD cannot express offsets outside the signed 33-bit range.
    const bool OffsetRepresentable = isInt<33>(BrOffset);
    if (!OffsetRepresentable)
      report_fatal_error(
          "Branch offsets outside of the signed 33-bit range not supported");

    // Page address of the target.
    BuildMI(MBB, MBB.end(), DL, get(AArch64::ADRP), AddrReg)
        .addSym(Target.getSymbol(), AArch64II::MO_PAGE);
    // Add the offset within the page.
    BuildMI(MBB, MBB.end(), DL, get(AArch64::ADDXri), AddrReg)
        .addReg(AddrReg)
        .addSym(Target.getSymbol(), AArch64II::MO_PAGEOFF | AArch64II::MO_NC)
        .addImm(0);
    // Jump to the computed address.
    BuildMI(MBB, MBB.end(), DL, get(AArch64::BR)).addReg(AddrReg);
  };

  RS->enterBasicBlockEnd(MBB);
  const Register Scavenged = RS->FindUnusedReg(&AArch64::GPR64RegClass);

  if (Scavenged != AArch64::NoRegister) {
    // A register is free: branch directly to the real destination and tell
    // the scavenger the register is now occupied.
    emitAddressAndJump(Scavenged, NewDestBB);
    RS->setRegUsed(Scavenged);
  } else {
    // No register is free, so X16 has to be spilled. The spill briefly moves
    // the stack pointer, which is incompatible with red zones.
    auto *FuncInfo = MBB.getParent()->getInfo<AArch64FunctionInfo>();
    const bool MayHaveRedZone =
        !FuncInfo || FuncInfo->hasRedZone().value_or(true);
    if (MayHaveRedZone)
      report_fatal_error(
          "Unable to insert indirect branch inside function that has red zone");

    const Register SpillReg = AArch64::X16;

    // Push X16 (pre-indexed store, SP -= 16).
    BuildMI(MBB, MBB.end(), DL, get(AArch64::STRXpre))
        .addReg(AArch64::SP, RegState::Define)
        .addReg(SpillReg)
        .addReg(AArch64::SP)
        .addImm(-16);

    // Branch to the restore block rather than the destination itself.
    emitAddressAndJump(SpillReg, RestoreBB);

    // Pop X16 in the restore block (post-indexed load, SP += 16).
    BuildMI(RestoreBB, RestoreBB.end(), DL, get(AArch64::LDRXpost))
        .addReg(AArch64::SP, RegState::Define)
        .addReg(SpillReg, RegState::Define)
        .addReg(AArch64::SP)
        .addImm(16);
  }
}

// Branch analysis.
bool AArch64InstrInfo::analyzeBranch(MachineBasicBlock &MBB,
MachineBasicBlock *&TBB,
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.h
Expand Up @@ -213,6 +213,11 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {

MachineBasicBlock *getBranchDestBlock(const MachineInstr &MI) const override;

/// Expand an unconditional branch into an ADRP + ADD + BR sequence placed in
/// the empty block \p MBB. If the register scavenger \p RS (must be non-null)
/// finds a free 64-bit GPR, the sequence branches to \p NewDestBB. Otherwise
/// X16 is spilled to the stack, the sequence branches to \p RestoreBB, and a
/// reload of X16 is emitted into \p RestoreBB (which must be empty on entry).
/// Reports a fatal error if \p BrOffset does not fit in a signed 33-bit
/// integer, or if a spill is needed in a function that has (or may have) a
/// red zone.
void insertIndirectBranch(MachineBasicBlock &MBB,
MachineBasicBlock &NewDestBB,
MachineBasicBlock &RestoreBB, const DebugLoc &DL,
int64_t BrOffset, RegScavenger *RS) const override;

bool analyzeBranch(MachineBasicBlock &MBB, MachineBasicBlock *&TBB,
MachineBasicBlock *&FBB,
SmallVectorImpl<MachineOperand> &Cond,
Expand Down
13 changes: 13 additions & 0 deletions llvm/lib/Target/TargetMachine.cpp
Expand Up @@ -78,6 +78,19 @@ void TargetMachine::resetTargetOptions(const Function &F) const {
/// and dynamic-no-pic.
Reloc::Model TargetMachine::getRelocationModel() const { return RM; }

/// Returns the maximum code size possible under the current code model, as
/// the largest value representable in an unsigned integer of a width chosen
/// per code model.
///
/// \returns maxUIntN(10) for Tiny, maxUIntN(31) for Small, Kernel and Medium,
///          and maxUIntN(64) for Large (also used as the conservative answer
///          for any unexpected enum value).
///
/// NOTE(review): the Large result is the all-ones 64-bit value; a caller that
/// converts it to a signed 64-bit offset would presumably see -1 -- confirm
/// that call sites handle this as intended.
uint64_t TargetMachine::getMaxCodeSize() const {
  // Every enumerator is listed and there is deliberately no `default` label,
  // so the compiler can still flag a newly added code model here.
  switch (getCodeModel()) {
  case CodeModel::Tiny:
    // NOTE(review): 10 bits bounds the tiny model at 1023 bytes -- confirm
    // this width is what is intended.
    return llvm::maxUIntN(10);
  case CodeModel::Small:
  case CodeModel::Kernel:
  case CodeModel::Medium:
    return llvm::maxUIntN(31);
  case CodeModel::Large:
    break;
  }

  // Reached for CodeModel::Large and for any value outside the enumerators.
  // Returning here, rather than inside the switch, guarantees that control
  // can never run off the end of this non-void function; the widest bound is
  // the conservative choice.
  return llvm::maxUIntN(64);
}

/// Get the IR-specified TLS model for Var.
static TLSModel::Model getSelectedTLSModel(const GlobalValue *GV) {
switch (GV->getThreadLocalMode()) {
Expand Down

0 comments on commit d7bca8e

Please sign in to comment.