Skip to content

Commit

Permalink
[ARM][LowOverheadLoops] TryRemove helper.
Browse files Browse the repository at this point in the history
Make a helper function that wraps around RDA::isSafeToRemove and
utilises the existing DCE IT block checks.
  • Loading branch information
sparker-arm committed Sep 30, 2020
1 parent 8392685 commit 779a8a0
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 58 deletions.
3 changes: 3 additions & 0 deletions llvm/lib/CodeGen/ReachingDefAnalysis.cpp
Expand Up @@ -635,6 +635,9 @@ void ReachingDefAnalysis::collectKilledOperands(MachineInstr *MI,
InstSet &Dead) const {
Dead.insert(MI);
auto IsDead = [this, &Dead](MachineInstr *Def, int PhysReg) {
if (mayHaveSideEffects(*Def))
return false;

unsigned LiveDefs = 0;
for (auto &MO : Def->operands()) {
if (!isValidRegDef(MO))
Expand Down
126 changes: 68 additions & 58 deletions llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp
Expand Up @@ -505,8 +505,6 @@ namespace {

void Expand(LowOverheadLoop &LoLoop);

void DCE(MachineInstr *MI, SmallPtrSetImpl<MachineInstr *> &ToRemove);

void IterationCountDCE(LowOverheadLoop &LoLoop);
};
}
Expand All @@ -521,6 +519,72 @@ std::map<MachineInstr *,
INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
false, false)

static bool TryRemove(MachineInstr *MI, ReachingDefAnalysis &RDA,
InstSet &ToRemove, InstSet &Ignore) {

// Check that we can remove all of Killed without having to modify any IT
// blocks.
auto WontCorruptITs = [](InstSet &Killed, ReachingDefAnalysis &RDA) {
// Collect the dead code and the MBBs in which they reside.
SmallPtrSet<MachineBasicBlock*, 2> BasicBlocks;
for (auto *Dead : Killed)
BasicBlocks.insert(Dead->getParent());

// Collect IT blocks in all affected basic blocks.
std::map<MachineInstr *, SmallPtrSet<MachineInstr *, 2>> ITBlocks;
for (auto *MBB : BasicBlocks) {
for (auto &IT : *MBB) {
if (IT.getOpcode() != ARM::t2IT)
continue;
RDA.getReachingLocalUses(&IT, ARM::ITSTATE, ITBlocks[&IT]);
}
}

// If we're removing all of the instructions within an IT block, then
// also remove the IT instruction.
SmallPtrSet<MachineInstr *, 2> ModifiedITs;
SmallPtrSet<MachineInstr *, 2> RemoveITs;
for (auto *Dead : Killed) {
if (MachineOperand *MO = Dead->findRegisterUseOperand(ARM::ITSTATE)) {
MachineInstr *IT = RDA.getMIOperand(Dead, *MO);
RemoveITs.insert(IT);
auto &CurrentBlock = ITBlocks[IT];
CurrentBlock.erase(Dead);
if (CurrentBlock.empty())
ModifiedITs.erase(IT);
else
ModifiedITs.insert(IT);
}
}
if (!ModifiedITs.empty())
return false;
Killed.insert(RemoveITs.begin(), RemoveITs.end());
return true;
};

SmallPtrSet<MachineInstr *, 2> Uses;
if (!RDA.isSafeToRemove(MI, Uses, Ignore))
return false;

if (WontCorruptITs(Uses, RDA)) {
ToRemove.insert(Uses.begin(), Uses.end());
LLVM_DEBUG(dbgs() << "ARM Loops: Able to remove: " << *MI
<< " - can also remove:\n";
for (auto *Use : Uses)
dbgs() << " - " << *Use);

SmallPtrSet<MachineInstr*, 4> Killed;
RDA.collectKilledOperands(MI, Killed);
if (WontCorruptITs(Killed, RDA)) {
ToRemove.insert(Killed.begin(), Killed.end());
LLVM_DEBUG(for (auto *Dead : Killed)
dbgs() << " - " << *Dead);
}
return true;
}
return false;
}

bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt) {
if (!StartInsertPt)
return false;
Expand Down Expand Up @@ -669,7 +733,7 @@ bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt) {

Ignore.insert(VCTPs.begin(), VCTPs.end());

if (RDA.isSafeToRemove(Def, ElementChain, Ignore)) {
if (TryRemove(Def, RDA, ElementChain, Ignore)) {
bool FoundSub = false;

for (auto *MI : ElementChain) {
Expand All @@ -683,10 +747,6 @@ bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt) {
} else
return false;
}

LLVM_DEBUG(dbgs() << "ARM Loops: Will remove element count chain:\n";
for (auto *MI : ElementChain)
dbgs() << " - " << *MI);
ToRemove.insert(ElementChain.begin(), ElementChain.end());
}
}
Expand Down Expand Up @@ -1300,52 +1360,6 @@ void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
MI->eraseFromParent();
}

void ARMLowOverheadLoops::DCE(MachineInstr *MI,
SmallPtrSetImpl<MachineInstr *> &ToRemove) {
// Collect the dead code and the MBBs in which they reside.
SmallPtrSet<MachineInstr*, 4> Killed;
RDA->collectKilledOperands(MI, Killed);
SmallPtrSet<MachineBasicBlock*, 2> BasicBlocks;
for (auto *Dead : Killed)
BasicBlocks.insert(Dead->getParent());

// Collect IT blocks in all affected basic blocks.
std::map<MachineInstr *, SmallPtrSet<MachineInstr *, 2>> ITBlocks;
for (auto *MBB : BasicBlocks) {
for (auto &IT : *MBB) {
if (IT.getOpcode() != ARM::t2IT)
continue;
RDA->getReachingLocalUses(&IT, ARM::ITSTATE, ITBlocks[&IT]);
}
}

// If we're removing all of the instructions within an IT block, then
// also remove the IT instruction.
SmallPtrSet<MachineInstr*, 2> ModifiedITs;
for (auto *Dead : Killed) {
if (MachineOperand *MO = Dead->findRegisterUseOperand(ARM::ITSTATE)) {
MachineInstr *IT = RDA->getMIOperand(Dead, *MO);
auto &CurrentBlock = ITBlocks[IT];
CurrentBlock.erase(Dead);
if (CurrentBlock.empty())
ModifiedITs.erase(IT);
else
ModifiedITs.insert(IT);
}
}

// Delete the killed instructions only if we don't have any IT blocks that
// need to be modified because we need to fixup the mask.
// TODO: Handle cases where IT blocks are modified.
if (ModifiedITs.empty()) {
LLVM_DEBUG(dbgs() << "ARM Loops: Will remove iteration count:\n";
for (auto *MI : Killed)
dbgs() << " - " << *MI);
ToRemove.insert(Killed.begin(), Killed.end());
} else
LLVM_DEBUG(dbgs() << "ARM Loops: Would need to modify IT block(s).\n");
}

// Perform dead code elimation on the loop iteration count setup expression.
// If we are tail-predicating, the number of elements to be processed is the
// operand of the VCTP instruction in the vector body, see getCount(), which is
Expand Down Expand Up @@ -1385,11 +1399,7 @@ void ARMLowOverheadLoops::IterationCountDCE(LowOverheadLoop &LoLoop) {
// Collect and remove the users of iteration count.
SmallPtrSet<MachineInstr*, 4> Killed = { LoLoop.Start, LoLoop.Dec,
LoLoop.End, LoLoop.InsertPt };
SmallPtrSet<MachineInstr*, 2> Remove;
if (RDA->isSafeToRemove(Def, Remove, Killed)) {
LoLoop.ToRemove.insert(Remove.begin(), Remove.end());
DCE(Def, LoLoop.ToRemove);
} else
if (!TryRemove(Def, *RDA, LoLoop.ToRemove, Killed))
LLVM_DEBUG(dbgs() << "ARM Loops: Unsafe to remove loop iteration count.\n");
}

Expand Down

0 comments on commit 779a8a0

Please sign in to comment.