Skip to content

Commit

Permalink
[ModuloSchedule] Add interface call to accept/reject SMS schedules
Browse files Browse the repository at this point in the history
This interface allows a target to reject a proposed
SMS schedule.  For Hexagon/PowerPC, all schedules
are accepted, leaving behavior unchanged.  For ARM,
schedules which exceed register pressure limits are
rejected.

Also, two RegisterPressureTracker methods now need to be public so
that register pressure can be computed by more callers.

Reapplication of D128941/(reversion:D132037) with small fix.

Differential Revision: https://reviews.llvm.org/D132170
  • Loading branch information
dpenry committed Aug 22, 2022
1 parent f82c55f commit ced705c
Show file tree
Hide file tree
Showing 5 changed files with 709 additions and 5 deletions.
10 changes: 5 additions & 5 deletions llvm/include/llvm/CodeGen/RegisterPressure.h
Expand Up @@ -537,6 +537,11 @@ class RegPressureTracker {

void dump() const;

void increaseRegPressure(Register RegUnit, LaneBitmask PreviousMask,
LaneBitmask NewMask);
void decreaseRegPressure(Register RegUnit, LaneBitmask PreviousMask,
LaneBitmask NewMask);

protected:
/// Add Reg to the live out set and increase max pressure.
void discoverLiveOut(RegisterMaskPair Pair);
Expand All @@ -547,11 +552,6 @@ class RegPressureTracker {
/// after the current position.
SlotIndex getCurrSlot() const;

void increaseRegPressure(Register RegUnit, LaneBitmask PreviousMask,
LaneBitmask NewMask);
void decreaseRegPressure(Register RegUnit, LaneBitmask PreviousMask,
LaneBitmask NewMask);

void bumpDeadDefs(ArrayRef<RegisterMaskPair> DeadDefs);

void bumpUpwardPressure(const MachineInstr *MI);
Expand Down
9 changes: 9 additions & 0 deletions llvm/include/llvm/CodeGen/TargetInstrInfo.h
Expand Up @@ -54,6 +54,8 @@ class ScheduleDAGMI;
class ScheduleHazardRecognizer;
class SDNode;
class SelectionDAG;
class SMSchedule;
class SwingSchedulerDAG;
class RegScavenger;
class TargetRegisterClass;
class TargetRegisterInfo;
Expand Down Expand Up @@ -729,6 +731,13 @@ class TargetInstrInfo : public MCInstrInfo {
/// update with no users being pipelined.
virtual bool shouldIgnoreForPipelining(const MachineInstr *MI) const = 0;

/// Return true if the proposed schedule should used. Otherwise return
/// false to not pipeline the loop. This function should be used to ensure
/// that pipelined loops meet target-specific quality heuristics.
virtual bool shouldUseSchedule(SwingSchedulerDAG &SSD, SMSchedule &SMS) {
return true;
}

/// Create a condition to determine if the trip count of the loop is greater
/// than TC, where TC is always one more than for the previous prologue or
/// 0 if this is being called for the outermost prologue.
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/CodeGen/MachinePipeliner.cpp
Expand Up @@ -2098,6 +2098,12 @@ bool SwingSchedulerDAG::schedulePipeline(SMSchedule &Schedule) {
<< " (II=" << Schedule.getInitiationInterval()
<< ")\n");

if (scheduleFound) {
scheduleFound = LoopPipelinerInfo->shouldUseSchedule(*this, Schedule);
if (!scheduleFound)
dbgs() << "Target rejected schedule\n";
}

if (scheduleFound) {
Schedule.finalizeSchedule(this);
Pass.ORE->emit([&]() {
Expand Down
161 changes: 161 additions & 0 deletions llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
Expand Up @@ -25,6 +25,7 @@
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Triple.h"
#include "llvm/CodeGen/DFAPacketizer.h"
#include "llvm/CodeGen/LiveVariables.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineConstantPool.h"
Expand All @@ -35,6 +36,7 @@
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/MachinePipeliner.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/MachineScheduler.h"
#include "llvm/CodeGen/MultiHazardRecognizer.h"
Expand Down Expand Up @@ -6756,6 +6758,19 @@ class ARMPipelinerLoopInfo : public TargetInstrInfo::PipelinerLoopInfo {
MachineFunction *MF;
const TargetInstrInfo *TII;

// Bitset[0 .. MAX_STAGES-1] ... iterations needed
// [LAST_IS_USE] : last reference to register in schedule is a use
// [SEEN_AS_LIVE] : Normal pressure algorithm believes register is live
static int constexpr MAX_STAGES = 30;
static int constexpr LAST_IS_USE = MAX_STAGES;
static int constexpr SEEN_AS_LIVE = MAX_STAGES + 1;
typedef std::bitset<MAX_STAGES + 2> IterNeed;
typedef std::map<unsigned, IterNeed> IterNeeds;

void bumpCrossIterationPressure(RegPressureTracker &RPT,
const IterNeeds &CIN);
bool tooMuchRegisterPressure(SwingSchedulerDAG &SSD, SMSchedule &SMS);

// Meanings of the various stuff with loop types:
// t2Bcc:
// EndLoop = branch at end of original BB that will become a kernel
Expand All @@ -6774,6 +6789,13 @@ class ARMPipelinerLoopInfo : public TargetInstrInfo::PipelinerLoopInfo {
return MI == EndLoop || MI == LoopCount;
}

bool shouldUseSchedule(SwingSchedulerDAG &SSD, SMSchedule &SMS) override {
if (tooMuchRegisterPressure(SSD, SMS))
return false;

return true;
}

Optional<bool> createTripCountGreaterCondition(
int TC, MachineBasicBlock &MBB,
SmallVectorImpl<MachineOperand> &Cond) override {
Expand Down Expand Up @@ -6812,6 +6834,145 @@ class ARMPipelinerLoopInfo : public TargetInstrInfo::PipelinerLoopInfo {

void disposed() override {}
};

void ARMPipelinerLoopInfo::bumpCrossIterationPressure(RegPressureTracker &RPT,
const IterNeeds &CIN) {
// Increase pressure by the amounts in CrossIterationNeeds
for (const auto &N : CIN) {
int Cnt = N.second.count() - N.second[SEEN_AS_LIVE] * 2;
for (int I = 0; I < Cnt; ++I)
RPT.increaseRegPressure(Register(N.first), LaneBitmask::getNone(),
LaneBitmask::getAll());
}
// Decrease pressure by the amounts in CrossIterationNeeds
for (const auto &N : CIN) {
int Cnt = N.second.count() - N.second[SEEN_AS_LIVE] * 2;
for (int I = 0; I < Cnt; ++I)
RPT.decreaseRegPressure(Register(N.first), LaneBitmask::getAll(),
LaneBitmask::getNone());
}
}

bool ARMPipelinerLoopInfo::tooMuchRegisterPressure(SwingSchedulerDAG &SSD,
SMSchedule &SMS) {
IterNeeds CrossIterationNeeds;

// Determine which values will be loop-carried after the schedule is
// applied

for (auto &SU : SSD.SUnits) {
const MachineInstr *MI = SU.getInstr();
int Stg = SMS.stageScheduled(const_cast<SUnit *>(&SU));
for (auto &S : SU.Succs)
if (MI->isPHI() && S.getKind() == SDep::Anti) {
Register Reg = S.getReg();
if (Register::isVirtualRegister(Reg))
CrossIterationNeeds.insert(std::make_pair(Reg.id(), IterNeed()))
.first->second.set(0);
} else if (S.isAssignedRegDep()) {
int OStg = SMS.stageScheduled(S.getSUnit());
if (OStg >= 0 && OStg != Stg) {
Register Reg = S.getReg();
if (Register::isVirtualRegister(Reg))
CrossIterationNeeds.insert(std::make_pair(Reg.id(), IterNeed()))
.first->second |= ((1 << (OStg - Stg)) - 1);
}
}
}

// Determine more-or-less what the proposed schedule (reversed) is going to
// be; it might not be quite the same because the within-cycle ordering
// created by SMSchedule depends upon changes to help with address offsets and
// the like.
std::vector<SUnit *> ProposedSchedule;
for (int Cycle = SMS.getFinalCycle(); Cycle >= SMS.getFirstCycle(); --Cycle)
for (int Stage = 0, StageEnd = SMS.getMaxStageCount(); Stage <= StageEnd;
++Stage) {
std::deque<SUnit *> Instrs =
SMS.getInstructions(Cycle + Stage * SMS.getInitiationInterval());
std::sort(Instrs.begin(), Instrs.end(),
[](SUnit *A, SUnit *B) { return A->NodeNum > B->NodeNum; });
for (SUnit *SU : Instrs)
ProposedSchedule.push_back(SU);
}

// Learn whether the last use/def of each cross-iteration register is a use or
// def. If it is a def, RegisterPressure will implicitly increase max pressure
// and we do not have to add the pressure.
for (auto SU : ProposedSchedule)
for (ConstMIBundleOperands OperI(*SU->getInstr()); OperI.isValid();
++OperI) {
auto MO = *OperI;
if (!MO.isReg() || !MO.getReg())
continue;
Register Reg = MO.getReg();
auto CIter = CrossIterationNeeds.find(Reg.id());
if (CIter == CrossIterationNeeds.end() || CIter->second[LAST_IS_USE] ||
CIter->second[SEEN_AS_LIVE])
continue;
if (MO.isDef() && !MO.isDead())
CIter->second.set(SEEN_AS_LIVE);
else if (MO.isUse())
CIter->second.set(LAST_IS_USE);
}
for (auto &CI : CrossIterationNeeds)
CI.second.reset(LAST_IS_USE);

RegionPressure RecRegPressure;
RegPressureTracker RPTracker(RecRegPressure);
RegisterClassInfo RegClassInfo;
RegClassInfo.runOnMachineFunction(*MF);
RPTracker.init(MF, &RegClassInfo, nullptr, EndLoop->getParent(),
EndLoop->getParent()->end(), false, false);
const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();

bumpCrossIterationPressure(RPTracker, CrossIterationNeeds);

for (auto SU : ProposedSchedule) {
MachineBasicBlock::const_iterator CurInstI = SU->getInstr();
RPTracker.setPos(std::next(CurInstI));
RPTracker.recede();

// Track what cross-iteration registers would be seen as live
for (ConstMIBundleOperands OperI(*CurInstI); OperI.isValid(); ++OperI) {
auto MO = *OperI;
if (!MO.isReg() || !MO.getReg())
continue;
Register Reg = MO.getReg();
if (MO.isDef() && !MO.isDead()) {
auto CIter = CrossIterationNeeds.find(Reg.id());
if (CIter != CrossIterationNeeds.end()) {
CIter->second.reset(0);
CIter->second.reset(SEEN_AS_LIVE);
}
}
}
for (auto &S : SU->Preds) {
auto Stg = SMS.stageScheduled(SU);
if (S.isAssignedRegDep()) {
Register Reg = S.getReg();
auto CIter = CrossIterationNeeds.find(Reg.id());
if (CIter != CrossIterationNeeds.end()) {
auto Stg2 = SMS.stageScheduled(const_cast<SUnit *>(S.getSUnit()));
assert(Stg2 <= Stg && "Data dependence upon earlier stage");
if (Stg - Stg2 < MAX_STAGES)
CIter->second.set(Stg - Stg2);
CIter->second.set(SEEN_AS_LIVE);
}
}
}

bumpCrossIterationPressure(RPTracker, CrossIterationNeeds);
}

auto &P = RPTracker.getPressure().MaxSetPressure;
for (unsigned I = 0, E = P.size(); I < E; ++I)
if (P[I] > TRI->getRegPressureSetLimit(*MF, I)) {
return true;
}
return false;
}

} // namespace

std::unique_ptr<TargetInstrInfo::PipelinerLoopInfo>
Expand Down

0 comments on commit ced705c

Please sign in to comment.