Skip to content

Commit

Permalink
[ARM][ParallelDSP] Change smlad insertion order
Browse files Browse the repository at this point in the history
Instead of inserting everything after the 'root' of the reduction,
insert all instructions as close to their operands as possible. This
can help reduce register pressure.

Differential Revision: https://reviews.llvm.org/D67392

llvm-svn: 374981
  • Loading branch information
sparker-arm committed Oct 16, 2019
1 parent ad76375 commit 1c3ca61
Show file tree
Hide file tree
Showing 13 changed files with 316 additions and 91 deletions.
67 changes: 51 additions & 16 deletions llvm/lib/Target/ARM/ARMParallelDSP.cpp
Expand Up @@ -18,6 +18,7 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Analysis/OrderedBasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/Transforms/Scalar.h"
Expand All @@ -42,6 +43,10 @@ static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
cl::desc("Disable the ARM Parallel DSP pass"));

// Hidden command-line option capping how many candidate loads are analysed.
// RecordMemoryOps returns false, recording no load pairs, when the number of
// collected loads exceeds this value. Defaults to 16.
// NOTE(review): presumably the cap exists to bound the pairwise load
// comparisons done while recording base/offset pairs - confirm.
static cl::opt<unsigned>
NumLoadLimit("arm-parallel-dsp-load-limit", cl::Hidden, cl::init(16),
cl::desc("Limit the number of loads analysed"));

namespace {
struct MulCandidate;
class Reduction;
Expand Down Expand Up @@ -346,6 +351,7 @@ bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
SmallVector<Instruction*, 8> Writes;
LoadPairs.clear();
WideLoads.clear();
OrderedBasicBlock OrderedBB(BB);

// Collect loads and instructions that may write to memory. For now we only
// record loads which are simple, sign-extended and have a single user.
Expand All @@ -360,38 +366,41 @@ bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
Loads.push_back(Ld);
}

if (Loads.empty() || Loads.size() > NumLoadLimit)
return false;

using InstSet = std::set<Instruction*>;
using DepMap = std::map<Instruction*, InstSet>;
DepMap RAWDeps;

// Record any writes that may alias a load.
const auto Size = LocationSize::unknown();
for (auto Read : Loads) {
for (auto Write : Writes) {
for (auto Write : Writes) {
for (auto Read : Loads) {
MemoryLocation ReadLoc =
MemoryLocation(Read->getPointerOperand(), Size);

if (!isModOrRefSet(intersectModRef(AA->getModRefInfo(Write, ReadLoc),
ModRefInfo::ModRef)))
continue;
if (DT->dominates(Write, Read))
if (OrderedBB.dominates(Write, Read))
RAWDeps[Read].insert(Write);
}
}

// Check that there is no write between the two loads which would
// prevent them from being safely merged.
auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
LoadInst *Dominator = DT->dominates(Base, Offset) ? Base : Offset;
LoadInst *Dominated = DT->dominates(Base, Offset) ? Offset : Base;
LoadInst *Dominator = OrderedBB.dominates(Base, Offset) ? Base : Offset;
LoadInst *Dominated = OrderedBB.dominates(Base, Offset) ? Offset : Base;

if (RAWDeps.count(Dominated)) {
InstSet &WritesBefore = RAWDeps[Dominated];

for (auto Before : WritesBefore) {
// We can't move the second load backward, past a write, to merge
// with the first load.
if (DT->dominates(Dominator, Before))
if (OrderedBB.dominates(Dominator, Before))
return false;
}
}
Expand All @@ -401,7 +410,7 @@ bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
// Record base, offset load pairs.
for (auto *Base : Loads) {
for (auto *Offset : Loads) {
if (Base == Offset)
if (Base == Offset || OffsetLoads.count(Offset))
continue;

if (AreSequentialAccesses<LoadInst>(Base, Offset, *DL, *SE) &&
Expand Down Expand Up @@ -613,7 +622,6 @@ bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {
return !R.getMulPairs().empty();
}


void ARMParallelDSP::InsertParallelMACs(Reduction &R) {

auto CreateSMLAD = [&](LoadInst* WideLd0, LoadInst *WideLd1,
Expand All @@ -633,39 +641,57 @@ void ARMParallelDSP::InsertParallelMACs(Reduction &R) {
Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);

IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
++BasicBlock::iterator(InsertAfter));
BasicBlock::iterator(InsertAfter));
Instruction *Call = Builder.CreateCall(SMLAD, Args);
NumSMLAD++;
return Call;
};

Instruction *InsertAfter = R.getRoot();
// Return the instruction after the dominated instruction.
auto GetInsertPoint = [this](Value *A, Value *B) {
assert((isa<Instruction>(A) || isa<Instruction>(B)) &&
"expected at least one instruction");

Value *V = nullptr;
if (!isa<Instruction>(A))
V = B;
else if (!isa<Instruction>(B))
V = A;
else
V = DT->dominates(cast<Instruction>(A), cast<Instruction>(B)) ? B : A;

return &*++BasicBlock::iterator(cast<Instruction>(V));
};

Value *Acc = R.getAccumulator();

// For any muls that were discovered but not paired, accumulate their values
// as before.
IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
++BasicBlock::iterator(InsertAfter));
IRBuilder<NoFolder> Builder(R.getRoot()->getParent());
MulCandList &MulCands = R.getMuls();
for (auto &MulCand : MulCands) {
if (MulCand->Paired)
continue;

Value *Mul = MulCand->Root;
Instruction *Mul = cast<Instruction>(MulCand->Root);
LLVM_DEBUG(dbgs() << "Accumulating unpaired mul: " << *Mul << "\n");

if (R.getType() != Mul->getType()) {
assert(R.is64Bit() && "expected 64-bit result");
Mul = Builder.CreateSExt(Mul, R.getType());
Builder.SetInsertPoint(&*++BasicBlock::iterator(Mul));
Mul = cast<Instruction>(Builder.CreateSExt(Mul, R.getRoot()->getType()));
}

if (!Acc) {
Acc = Mul;
continue;
}

// If Acc is the original incoming value to the reduction, it could be a
// phi. But the phi will dominate Mul, meaning that Mul will be the
// insertion point.
Builder.SetInsertPoint(GetInsertPoint(Mul, Acc));
Acc = Builder.CreateAdd(Mul, Acc);
InsertAfter = cast<Instruction>(Acc);
}

if (!Acc) {
Expand All @@ -677,6 +703,14 @@ void ARMParallelDSP::InsertParallelMACs(Reduction &R) {
Acc = Builder.CreateSExt(Acc, R.getType());
}

// Roughly sort the mul pairs in their program order.
OrderedBasicBlock OrderedBB(R.getRoot()->getParent());
llvm::sort(R.getMulPairs(), [&OrderedBB](auto &PairA, auto &PairB) {
const Instruction *A = PairA.first->Root;
const Instruction *B = PairB.first->Root;
return OrderedBB.dominates(A, B);
});

IntegerType *Ty = IntegerType::get(M->getContext(), 32);
for (auto &Pair : R.getMulPairs()) {
MulCandidate *LHSMul = Pair.first;
Expand All @@ -688,8 +722,9 @@ void ARMParallelDSP::InsertParallelMACs(Reduction &R) {
LoadInst *WideRHS = WideLoads.count(BaseRHS) ?
WideLoads[BaseRHS]->getLoad() : CreateWideLoad(RHSMul->VecLd, Ty);

Instruction *InsertAfter = GetInsertPoint(WideLHS, WideRHS);
InsertAfter = GetInsertPoint(InsertAfter, Acc);
Acc = CreateSMLAD(WideLHS, WideRHS, Acc, RHSMul->Exchange, InsertAfter);
InsertAfter = cast<Instruction>(Acc);
}
R.UpdateRoot(cast<Instruction>(Acc));
}
Expand Down
159 changes: 159 additions & 0 deletions llvm/test/CodeGen/ARM/ParallelDSP/blocks.ll
Expand Up @@ -134,3 +134,162 @@ bb.1:
ret i32 %res
}

; TODO: Four smlads should be generated here, but mul.0 and mul.3 remain as
; scalars.
; CHECK-LABEL: num_load_limit
; CHECK: call i32 @llvm.arm.smlad
; CHECK: call i32 @llvm.arm.smlad
; CHECK: call i32 @llvm.arm.smlad
; CHECK-NOT: call i32 @llvm.arm.smlad
; Multiply-accumulate reduction over i16 elements 0..7 of %a and %b. Each
; element is loaded, sign-extended to i32 and multiplied; the products are
; summed together with %acc and returned. There are 16 i16 loads in total
; (eight through each pointer).
; NOTE(review): presumably this is meant to sit exactly at the default value
; (16) of the arm-parallel-dsp-load-limit option - confirm against the pass.
define i32 @num_load_limit(i16* %a, i16* %b, i32 %acc) {
entry:
; Elements 0 and 1: %mul.0 = a[0] * b[0], %mul.1 = a[1] * b[1].
%ld.a.0 = load i16, i16* %a
%sext.a.0 = sext i16 %ld.a.0 to i32
%ld.b.0 = load i16, i16* %b
%sext.b.0 = sext i16 %ld.b.0 to i32
%mul.0 = mul i32 %sext.a.0, %sext.b.0
%addr.a.1 = getelementptr i16, i16* %a, i32 1
%addr.b.1 = getelementptr i16, i16* %b, i32 1
%ld.a.1 = load i16, i16* %addr.a.1
%sext.a.1 = sext i16 %ld.a.1 to i32
%ld.b.1 = load i16, i16* %addr.b.1
%sext.b.1 = sext i16 %ld.b.1 to i32
%mul.1 = mul i32 %sext.a.1, %sext.b.1
%add.0 = add i32 %mul.0, %mul.1

; Elements 2 and 3 are loaded and extended, but %mul.2 reuses %sext.a.0 and
; %sext.b.0, and %mul.3 multiplies %sext.a.1 by %sext.b.3. As a result
; %sext.a.2, %sext.b.2 and %sext.a.3 have no users in this function.
%addr.a.2 = getelementptr i16, i16* %a, i32 2
%addr.b.2 = getelementptr i16, i16* %b, i32 2
%ld.a.2 = load i16, i16* %addr.a.2
%sext.a.2 = sext i16 %ld.a.2 to i32
%ld.b.2 = load i16, i16* %addr.b.2
%sext.b.2 = sext i16 %ld.b.2 to i32
%mul.2 = mul i32 %sext.a.0, %sext.b.0
%addr.a.3 = getelementptr i16, i16* %a, i32 3
%addr.b.3 = getelementptr i16, i16* %b, i32 3
%ld.a.3 = load i16, i16* %addr.a.3
%sext.a.3 = sext i16 %ld.a.3 to i32
%ld.b.3 = load i16, i16* %addr.b.3
%sext.b.3 = sext i16 %ld.b.3 to i32
%mul.3 = mul i32 %sext.a.1, %sext.b.3
%add.3 = add i32 %mul.2, %mul.3

; Elements 4 and 5: %mul.4 = a[4] * b[4], %mul.5 = a[5] * b[5].
%addr.a.4 = getelementptr i16, i16* %a, i32 4
%addr.b.4 = getelementptr i16, i16* %b, i32 4
%ld.a.4 = load i16, i16* %addr.a.4
%sext.a.4 = sext i16 %ld.a.4 to i32
%ld.b.4 = load i16, i16* %addr.b.4
%sext.b.4 = sext i16 %ld.b.4 to i32
%mul.4 = mul i32 %sext.a.4, %sext.b.4
%addr.a.5 = getelementptr i16, i16* %a, i32 5
%addr.b.5 = getelementptr i16, i16* %b, i32 5
%ld.a.5 = load i16, i16* %addr.a.5
%sext.a.5 = sext i16 %ld.a.5 to i32
%ld.b.5 = load i16, i16* %addr.b.5
%sext.b.5 = sext i16 %ld.b.5 to i32
%mul.5 = mul i32 %sext.a.5, %sext.b.5
%add.5 = add i32 %mul.4, %mul.5

; Elements 6 and 7: %mul.6 = a[6] * b[6], %mul.7 = a[7] * b[7].
%addr.a.6 = getelementptr i16, i16* %a, i32 6
%addr.b.6 = getelementptr i16, i16* %b, i32 6
%ld.a.6 = load i16, i16* %addr.a.6
%sext.a.6 = sext i16 %ld.a.6 to i32
%ld.b.6 = load i16, i16* %addr.b.6
%sext.b.6 = sext i16 %ld.b.6 to i32
%mul.6 = mul i32 %sext.a.6, %sext.b.6
%addr.a.7 = getelementptr i16, i16* %a, i32 7
%addr.b.7 = getelementptr i16, i16* %b, i32 7
%ld.a.7 = load i16, i16* %addr.a.7
%sext.a.7 = sext i16 %ld.a.7 to i32
%ld.b.7 = load i16, i16* %addr.b.7
%sext.b.7 = sext i16 %ld.b.7 to i32
%mul.7 = mul i32 %sext.a.7, %sext.b.7
%add.7 = add i32 %mul.6, %mul.7

; Combine the four partial sums, then add the incoming accumulator last.
%add.10 = add i32 %add.7, %add.5
%add.11 = add i32 %add.3, %add.0
%add.12 = add i32 %add.10, %add.11
%res = add i32 %add.12, %acc
ret i32 %res
}

; CHECK-LABEL: too_many_loads
; CHECK-NOT: call i32 @llvm.arm.smlad
; Same multiply-accumulate shape as a reduction over elements 0..7 of %a and
; %b, with one extra load/sext/mul group (%mul.8) appended. That gives 18 i16
; loads in total (nine through each pointer). The FileCheck lines preceding
; this function expect no smlad call in the output.
; NOTE(review): presumably the transform is skipped because 18 exceeds the
; default of the arm-parallel-dsp-load-limit option (16) - confirm.
define i32 @too_many_loads(i16* %a, i16* %b, i32 %acc) {
entry:
; Elements 0 and 1: %mul.0 = a[0] * b[0], %mul.1 = a[1] * b[1].
%ld.a.0 = load i16, i16* %a
%sext.a.0 = sext i16 %ld.a.0 to i32
%ld.b.0 = load i16, i16* %b
%sext.b.0 = sext i16 %ld.b.0 to i32
%mul.0 = mul i32 %sext.a.0, %sext.b.0
%addr.a.1 = getelementptr i16, i16* %a, i32 1
%addr.b.1 = getelementptr i16, i16* %b, i32 1
%ld.a.1 = load i16, i16* %addr.a.1
%sext.a.1 = sext i16 %ld.a.1 to i32
%ld.b.1 = load i16, i16* %addr.b.1
%sext.b.1 = sext i16 %ld.b.1 to i32
%mul.1 = mul i32 %sext.a.1, %sext.b.1
%add.0 = add i32 %mul.0, %mul.1

; Elements 2 and 3 are loaded and extended, but %mul.2 reuses %sext.a.0 and
; %sext.b.0, and %mul.3 multiplies %sext.a.1 by %sext.b.3. As a result
; %sext.a.2, %sext.b.2 and %sext.a.3 have no users in this function.
%addr.a.2 = getelementptr i16, i16* %a, i32 2
%addr.b.2 = getelementptr i16, i16* %b, i32 2
%ld.a.2 = load i16, i16* %addr.a.2
%sext.a.2 = sext i16 %ld.a.2 to i32
%ld.b.2 = load i16, i16* %addr.b.2
%sext.b.2 = sext i16 %ld.b.2 to i32
%mul.2 = mul i32 %sext.a.0, %sext.b.0
%addr.a.3 = getelementptr i16, i16* %a, i32 3
%addr.b.3 = getelementptr i16, i16* %b, i32 3
%ld.a.3 = load i16, i16* %addr.a.3
%sext.a.3 = sext i16 %ld.a.3 to i32
%ld.b.3 = load i16, i16* %addr.b.3
%sext.b.3 = sext i16 %ld.b.3 to i32
%mul.3 = mul i32 %sext.a.1, %sext.b.3
%add.3 = add i32 %mul.2, %mul.3

; Elements 4 and 5: %mul.4 = a[4] * b[4], %mul.5 = a[5] * b[5].
%addr.a.4 = getelementptr i16, i16* %a, i32 4
%addr.b.4 = getelementptr i16, i16* %b, i32 4
%ld.a.4 = load i16, i16* %addr.a.4
%sext.a.4 = sext i16 %ld.a.4 to i32
%ld.b.4 = load i16, i16* %addr.b.4
%sext.b.4 = sext i16 %ld.b.4 to i32
%mul.4 = mul i32 %sext.a.4, %sext.b.4
%addr.a.5 = getelementptr i16, i16* %a, i32 5
%addr.b.5 = getelementptr i16, i16* %b, i32 5
%ld.a.5 = load i16, i16* %addr.a.5
%sext.a.5 = sext i16 %ld.a.5 to i32
%ld.b.5 = load i16, i16* %addr.b.5
%sext.b.5 = sext i16 %ld.b.5 to i32
%mul.5 = mul i32 %sext.a.5, %sext.b.5
%add.5 = add i32 %mul.4, %mul.5

; Elements 6 and 7: %mul.6 = a[6] * b[6], %mul.7 = a[7] * b[7].
%addr.a.6 = getelementptr i16, i16* %a, i32 6
%addr.b.6 = getelementptr i16, i16* %b, i32 6
%ld.a.6 = load i16, i16* %addr.a.6
%sext.a.6 = sext i16 %ld.a.6 to i32
%ld.b.6 = load i16, i16* %addr.b.6
%sext.b.6 = sext i16 %ld.b.6 to i32
%mul.6 = mul i32 %sext.a.6, %sext.b.6
%addr.a.7 = getelementptr i16, i16* %a, i32 7
%addr.b.7 = getelementptr i16, i16* %b, i32 7
%ld.a.7 = load i16, i16* %addr.a.7
%sext.a.7 = sext i16 %ld.a.7 to i32
%ld.b.7 = load i16, i16* %addr.b.7
%sext.b.7 = sext i16 %ld.b.7 to i32
%mul.7 = mul i32 %sext.a.7, %sext.b.7
%add.7 = add i32 %mul.6, %mul.7

; Extra, unpaired group that pushes the load count to 18.
; NOTE(review): both GEPs below use index 7, so these loads read the same
; addresses as %ld.a.7 / %ld.b.7 rather than element 8 - confirm whether
; that is intentional.
%addr.a.8 = getelementptr i16, i16* %a, i32 7
%addr.b.8 = getelementptr i16, i16* %b, i32 7
%ld.a.8 = load i16, i16* %addr.a.8
%sext.a.8 = sext i16 %ld.a.8 to i32
%ld.b.8 = load i16, i16* %addr.b.8
%sext.b.8 = sext i16 %ld.b.8 to i32
%mul.8 = mul i32 %sext.a.8, %sext.b.8

; Combine the four partial sums, add the incoming accumulator, then add the
; extra product last.
%add.10 = add i32 %add.7, %add.5
%add.11 = add i32 %add.3, %add.0
%add.12 = add i32 %add.10, %add.11
%add.13 = add i32 %add.12, %acc
%res = add i32 %add.13, %mul.8
ret i32 %res
}

0 comments on commit 1c3ca61

Please sign in to comment.