Skip to content

Commit

Permalink
[ARM] Introduce t2DoLoopStartTP
Browse files Browse the repository at this point in the history
This introduces a new pseudo instruction, almost identical to a
t2DoLoopStart but taking 2 parameters - the original loop iteration
count needed for a low overhead loop, plus the VCTP element count needed
for a DLSTP instruction setting up a tail predicated loop. The idea is
that the instruction holds both values and the backend
ARMLowOverheadLoops pass can pick between the two, depending on whether
it creates a tail predicated loop or falls back to a low overhead loop.

To do that there needs to be something that converts a t2DoLoopStart to
a t2DoLoopStartTP, for which this patch repurposes the
MVEVPTOptimisationsPass as a "tail predication and vpt optimisation"
pass. The extra operand for the t2DoLoopStartTP is chosen based on the
operands of the VCTPs in the loop, and the instruction is moved as late in
the block as possible to attempt to increase the likelihood of making
tail predicated loops.

Differential Revision: https://reviews.llvm.org/D90591
  • Loading branch information
davemgreen committed Nov 10, 2020
1 parent 02af110 commit 08d1c2d
Show file tree
Hide file tree
Showing 14 changed files with 390 additions and 237 deletions.
4 changes: 2 additions & 2 deletions llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
Expand Up @@ -5949,8 +5949,8 @@ ARMBaseInstrInfo::getOutliningType(MachineBasicBlock::iterator &MIT,

// Be conservative with ARMv8.1 MVE instructions.
if (Opc == ARM::t2BF_LabelPseudo || Opc == ARM::t2DoLoopStart ||
Opc == ARM::t2WhileLoopStart || Opc == ARM::t2LoopDec ||
Opc == ARM::t2LoopEnd)
Opc == ARM::t2DoLoopStartTP || Opc == ARM::t2WhileLoopStart ||
Opc == ARM::t2LoopDec || Opc == ARM::t2LoopEnd)
return outliner::InstrType::Illegal;

const MCInstrDesc &MCID = MI.getDesc();
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/ARM/ARMBaseInstrInfo.h
Expand Up @@ -660,6 +660,7 @@ static inline bool isVCTP(const MachineInstr *MI) {
/// Returns true when \p MI is one of the low-overhead-loop start pseudo
/// instructions: t2DoLoopStart, t2DoLoopStartTP or t2WhileLoopStart.
static inline
bool isLoopStart(MachineInstr &MI) {
  switch (MI.getOpcode()) {
  case ARM::t2DoLoopStart:
  case ARM::t2DoLoopStartTP:
  case ARM::t2WhileLoopStart:
    return true;
  default:
    return false;
  }
}

Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/ARM/ARMInstrThumb2.td
Expand Up @@ -5427,6 +5427,9 @@ def t2DoLoopStart :
t2PseudoInst<(outs GPRlr:$X), (ins rGPR:$elts), 4, IIC_Br,
[(set GPRlr:$X, (int_start_loop_iterations rGPR:$elts))]>;

// Variant of t2DoLoopStart that carries two input values. ARMLowOverheadLoops
// reads operand 1 ($elts) as the loop-start count when it falls back to a
// plain low overhead loop, and operand 2 ($count) as the tail-predication
// element count. The record has an empty selection pattern, so it is not
// matched directly from an intrinsic.
def t2DoLoopStartTP :
t2PseudoInst<(outs GPRlr:$X), (ins rGPR:$elts, rGPR:$count), 4, IIC_Br, []>;

let hasSideEffects = 0 in
def t2LoopDec :
t2PseudoInst<(outs GPRlr:$Rm), (ins GPRlr:$Rn, imm0_7:$size),
Expand Down
176 changes: 91 additions & 85 deletions llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp
Expand Up @@ -101,6 +101,10 @@ static bool shouldInspect(MachineInstr &MI) {
hasVPRUse(&MI);
}

/// Returns true if \p MI should be handled as a "do"-style loop start.
/// Every opcode other than t2WhileLoopStart is classified as "do", so this
/// is only meaningful when \p MI is already known to be a loop start.
static bool isDo(MachineInstr *MI) {
  const bool IsWhileStart = MI->getOpcode() == ARM::t2WhileLoopStart;
  return !IsWhileStart;
}

namespace {

using InstSet = SmallPtrSetImpl<MachineInstr *>;
Expand Down Expand Up @@ -431,12 +435,11 @@ namespace {
MachineOperand &getLoopStartOperand() {
if (IsTailPredicationLegal())
return TPNumElements;
return Start->getOpcode() == ARM::t2DoLoopStart ? Start->getOperand(1)
: Start->getOperand(0);
return isDo(Start) ? Start->getOperand(1) : Start->getOperand(0);
}

unsigned getStartOpcode() const {
bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
bool IsDo = isDo(Start);
if (!IsTailPredicationLegal())
return IsDo ? ARM::t2DLS : ARM::t2WLS;

Expand Down Expand Up @@ -622,12 +625,10 @@ bool LowOverheadLoop::ValidateTailPredicate() {
// count instead of iteration count, won't affect any other instructions
// than the LoopStart and LoopDec.
// TODO: We should try to insert the [W|D]LSTP after any of the other uses.
Register StartReg = Start->getOpcode() == ARM::t2DoLoopStart
? Start->getOperand(1).getReg()
: Start->getOperand(0).getReg();
Register StartReg = isDo(Start) ? Start->getOperand(1).getReg()
: Start->getOperand(0).getReg();
if (StartInsertPt == Start && StartReg == ARM::LR) {
if (auto *IterCount = RDA.getMIOperand(
Start, Start->getOpcode() == ARM::t2DoLoopStart ? 1 : 0)) {
if (auto *IterCount = RDA.getMIOperand(Start, isDo(Start) ? 1 : 0)) {
SmallPtrSet<MachineInstr *, 2> Uses;
RDA.getGlobalUses(IterCount, MCRegister::from(ARM::LR), Uses);
for (auto *Use : Uses) {
Expand All @@ -644,53 +645,88 @@ bool LowOverheadLoop::ValidateTailPredicate() {
// elements is provided to the vctp instruction, so we need to check that
// we can use this register at InsertPt.
MachineInstr *VCTP = VCTPs.back();
TPNumElements = VCTP->getOperand(1);
MCRegister NumElements = TPNumElements.getReg().asMCReg();

// If the register is defined within loop, then we can't perform TP.
// TODO: Check whether this is just a mov of a register that would be
// available.
if (RDA.hasLocalDefBefore(VCTP, NumElements)) {
LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n");
return false;
}
if (Start->getOpcode() == ARM::t2DoLoopStartTP) {
TPNumElements = Start->getOperand(2);
StartInsertPt = Start;
StartInsertBB = Start->getParent();
} else {
TPNumElements = VCTP->getOperand(1);
MCRegister NumElements = TPNumElements.getReg().asMCReg();

// If the register is defined within loop, then we can't perform TP.
// TODO: Check whether this is just a mov of a register that would be
// available.
if (RDA.hasLocalDefBefore(VCTP, NumElements)) {
LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n");
return false;
}

// The element count register may be defined after InsertPt, in which case we
// need to try to move either InsertPt or the def so that the [w|d]lstp can
// use the value.

if (StartInsertPt != StartInsertBB->end() &&
!RDA.isReachingDefLiveOut(&*StartInsertPt, NumElements)) {
if (auto *ElemDef = RDA.getLocalLiveOutMIDef(StartInsertBB, NumElements)) {
if (RDA.isSafeToMoveForwards(ElemDef, &*StartInsertPt)) {
ElemDef->removeFromParent();
StartInsertBB->insert(StartInsertPt, ElemDef);
LLVM_DEBUG(dbgs() << "ARM Loops: Moved element count def: "
<< *ElemDef);
} else if (RDA.isSafeToMoveBackwards(&*StartInsertPt, ElemDef)) {
StartInsertPt->removeFromParent();
StartInsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef),
&*StartInsertPt);
LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef);
} else {
// If we fail to move an instruction and the element count is provided
// by a mov, use the mov operand if it will have the same value at the
// insertion point
MachineOperand Operand = ElemDef->getOperand(1);
if (isMovRegOpcode(ElemDef->getOpcode()) &&
RDA.getUniqueReachingMIDef(ElemDef, Operand.getReg().asMCReg()) ==
RDA.getUniqueReachingMIDef(&*StartInsertPt,
Operand.getReg().asMCReg())) {
TPNumElements = Operand;
NumElements = TPNumElements.getReg();
} else {
// The element count register may be defined after InsertPt, in which case we
// need to try to move either InsertPt or the def so that the [w|d]lstp can
// use the value.

if (StartInsertPt != StartInsertBB->end() &&
!RDA.isReachingDefLiveOut(&*StartInsertPt, NumElements)) {
if (auto *ElemDef =
RDA.getLocalLiveOutMIDef(StartInsertBB, NumElements)) {
if (RDA.isSafeToMoveForwards(ElemDef, &*StartInsertPt)) {
ElemDef->removeFromParent();
StartInsertBB->insert(StartInsertPt, ElemDef);
LLVM_DEBUG(dbgs()
<< "ARM Loops: Unable to move element count to loop "
<< "start instruction.\n");
return false;
<< "ARM Loops: Moved element count def: " << *ElemDef);
} else if (RDA.isSafeToMoveBackwards(&*StartInsertPt, ElemDef)) {
StartInsertPt->removeFromParent();
StartInsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef),
&*StartInsertPt);
LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef);
} else {
// If we fail to move an instruction and the element count is provided
// by a mov, use the mov operand if it will have the same value at the
// insertion point
MachineOperand Operand = ElemDef->getOperand(1);
if (isMovRegOpcode(ElemDef->getOpcode()) &&
RDA.getUniqueReachingMIDef(ElemDef, Operand.getReg().asMCReg()) ==
RDA.getUniqueReachingMIDef(&*StartInsertPt,
Operand.getReg().asMCReg())) {
TPNumElements = Operand;
NumElements = TPNumElements.getReg();
} else {
LLVM_DEBUG(dbgs()
<< "ARM Loops: Unable to move element count to loop "
<< "start instruction.\n");
return false;
}
}
}
}

// Especially in the case of while loops, InsertBB may not be the
// preheader, so we need to check that the register isn't redefined
// before entering the loop.
auto CannotProvideElements = [this](MachineBasicBlock *MBB,
MCRegister NumElements) {
if (MBB->empty())
return false;
// NumElements is redefined in this block.
if (RDA.hasLocalDefBefore(&MBB->back(), NumElements))
return true;

// Don't continue searching up through multiple predecessors.
if (MBB->pred_size() > 1)
return true;

return false;
};

// Search backwards for a def, until we get to InsertBB.
MachineBasicBlock *MBB = Preheader;
while (MBB && MBB != StartInsertBB) {
if (CannotProvideElements(MBB, NumElements)) {
LLVM_DEBUG(dbgs() << "ARM Loops: Unable to provide element count.\n");
return false;
}
MBB = *MBB->pred_begin();
}
}

// Could inserting the [W|D]LSTP cause some unintended effects? In a perfect
Expand All @@ -717,34 +753,6 @@ bool LowOverheadLoop::ValidateTailPredicate() {
if (CannotInsertWDLSTPBetween(StartInsertPt, StartInsertBB->end()))
return false;

// Especially in the case of while loops, InsertBB may not be the
// preheader, so we need to check that the register isn't redefined
// before entering the loop.
auto CannotProvideElements = [this](MachineBasicBlock *MBB,
MCRegister NumElements) {
if (MBB->empty())
return false;
// NumElements is redefined in this block.
if (RDA.hasLocalDefBefore(&MBB->back(), NumElements))
return true;

// Don't continue searching up through multiple predecessors.
if (MBB->pred_size() > 1)
return true;

return false;
};

// Search backwards for a def, until we get to InsertBB.
MachineBasicBlock *MBB = Preheader;
while (MBB && MBB != StartInsertBB) {
if (CannotProvideElements(MBB, NumElements)) {
LLVM_DEBUG(dbgs() << "ARM Loops: Unable to provide element count.\n");
return false;
}
MBB = *MBB->pred_begin();
}

// Check that the value change of the element count is what we expect and
// that the predication will be equivalent. For this we need:
// NumElements = NumElements - VectorWidth. The sub will be a sub immediate
Expand All @@ -753,7 +761,7 @@ bool LowOverheadLoop::ValidateTailPredicate() {
return -getAddSubImmediate(*MI) == ExpectedVecWidth;
};

MBB = VCTP->getParent();
MachineBasicBlock *MBB = VCTP->getParent();
// Remove modifications to the element count since they have no purpose in a
// tail predicated loop. Explicitly refer to the vctp operand no matter which
// register NumElements has been assigned to, since that is what the
Expand Down Expand Up @@ -1062,8 +1070,7 @@ void LowOverheadLoop::Validate(ARMBasicBlockUtils *BBUtils) {
InstSet &ToRemove) {
// For a t2DoLoopStart it is always valid to use the start insertion point.
// For WLS we can define LR if LR already contains the same value.
if (Start->getOpcode() == ARM::t2DoLoopStart ||
Start->getOperand(0).getReg() == ARM::LR) {
if (isDo(Start) || Start->getOperand(0).getReg() == ARM::LR) {
InsertPt = MachineBasicBlock::iterator(Start);
InsertBB = Start->getParent();
return true;
Expand Down Expand Up @@ -1434,8 +1441,8 @@ void ARMLowOverheadLoops::IterationCountDCE(LowOverheadLoop &LoLoop) {

LLVM_DEBUG(dbgs() << "ARM Loops: Trying DCE on loop iteration count.\n");

MachineInstr *Def = RDA->getMIOperand(
LoLoop.Start, LoLoop.Start->getOpcode() == ARM::t2DoLoopStart ? 1 : 0);
MachineInstr *Def =
RDA->getMIOperand(LoLoop.Start, isDo(LoLoop.Start) ? 1 : 0);
if (!Def) {
LLVM_DEBUG(dbgs() << "ARM Loops: Couldn't find iteration count.\n");
return;
Expand All @@ -1457,7 +1464,6 @@ MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
MachineBasicBlock::iterator InsertPt = LoLoop.StartInsertPt;
MachineInstr *Start = LoLoop.Start;
MachineBasicBlock *MBB = LoLoop.StartInsertBB;
bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
unsigned Opc = LoLoop.getStartOpcode();
MachineOperand &Count = LoLoop.getLoopStartOperand();

Expand All @@ -1466,7 +1472,7 @@ MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {

MIB.addDef(ARM::LR);
MIB.add(Count);
if (!IsDo)
if (!isDo(Start))
MIB.add(Start->getOperand(1));

LoLoop.ToRemove.insert(Start);
Expand Down

0 comments on commit 08d1c2d

Please sign in to comment.