Skip to content

Commit

Permalink
[ARM] Push gather/scatter shl index updates out of loops
Browse files Browse the repository at this point in the history
This teaches the MVE gather/scatter lowering pass that SHL is
essentially the same as Mul, so that we are able to optimize the
induction of a gather/scatter address by pushing shl index updates out of loops.
https://alive2.llvm.org/ce/z/wG4VyT

Differential Revision: https://reviews.llvm.org/D112920
  • Loading branch information
davemgreen committed Nov 3, 2021
1 parent 52615df commit d36dd1f
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 36 deletions.
43 changes: 25 additions & 18 deletions llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp
Expand Up @@ -149,10 +149,10 @@ class MVEGatherScatterLowering : public FunctionPass {
bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
// Pushes the given add out of the loop
void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
// Pushes the given mul out of the loop
void pushOutMul(PHINode *&Phi, Value *IncrementPerRound,
Value *OffsSecondOperand, unsigned LoopIncrement,
IRBuilder<> &Builder);
// Pushes the given mul or shl out of the loop
void pushOutMulShl(unsigned Opc, PHINode *&Phi, Value *IncrementPerRound,
Value *OffsSecondOperand, unsigned LoopIncrement,
IRBuilder<> &Builder);
};

} // end anonymous namespace
Expand Down Expand Up @@ -342,7 +342,8 @@ Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {

const Instruction *I = cast<Instruction>(V);
if (I->getOpcode() == Instruction::Add ||
I->getOpcode() == Instruction::Mul) {
I->getOpcode() == Instruction::Mul ||
I->getOpcode() == Instruction::Shl) {
Optional<int64_t> Op0 = getIfConst(I->getOperand(0));
Optional<int64_t> Op1 = getIfConst(I->getOperand(1));
if (!Op0 || !Op1)
Expand All @@ -351,6 +352,8 @@ Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
return Optional<int64_t>{Op0.getValue() + Op1.getValue()};
if (I->getOpcode() == Instruction::Mul)
return Optional<int64_t>{Op0.getValue() * Op1.getValue()};
if (I->getOpcode() == Instruction::Shl)
return Optional<int64_t>{Op0.getValue() << Op1.getValue()};
}
return Optional<int64_t>{};
}
Expand Down Expand Up @@ -888,11 +891,11 @@ void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
Phi->removeIncomingValue(StartIndex);
}

void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
Value *IncrementPerRound,
Value *OffsSecondOperand,
unsigned LoopIncrement,
IRBuilder<> &Builder) {
void MVEGatherScatterLowering::pushOutMulShl(unsigned Opcode, PHINode *&Phi,
Value *IncrementPerRound,
Value *OffsSecondOperand,
unsigned LoopIncrement,
IRBuilder<> &Builder) {
LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");

// Create a new scalar add outside of the loop and transform it to a splat
Expand All @@ -901,12 +904,13 @@ void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());

// Create a new index
Value *StartIndex = BinaryOperator::Create(
Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
OffsSecondOperand, "PushedOutMul", InsertionPoint);
Value *StartIndex =
BinaryOperator::Create((Instruction::BinaryOps)Opcode,
Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
OffsSecondOperand, "PushedOutMul", InsertionPoint);

Instruction *Product =
BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
BinaryOperator::Create((Instruction::BinaryOps)Opcode, IncrementPerRound,
OffsSecondOperand, "Product", InsertionPoint);
// Increment NewIndex by Product instead of the multiplication
Instruction *NewIncrement = BinaryOperator::Create(
Expand Down Expand Up @@ -936,7 +940,8 @@ static bool hasAllGatScatUsers(Instruction *I) {
return Gatscat;
} else {
unsigned OpCode = cast<Instruction>(U)->getOpcode();
if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
if ((OpCode == Instruction::Add || OpCode == Instruction::Mul ||
OpCode == Instruction::Shl) &&
hasAllGatScatUsers(cast<Instruction>(U))) {
continue;
}
Expand All @@ -956,7 +961,8 @@ bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
return false;
Instruction *Offs = cast<Instruction>(Offsets);
if (Offs->getOpcode() != Instruction::Add &&
Offs->getOpcode() != Instruction::Mul)
Offs->getOpcode() != Instruction::Mul &&
Offs->getOpcode() != Instruction::Shl)
return false;
Loop *L = LI->getLoopFor(BB);
if (L == nullptr)
Expand Down Expand Up @@ -1063,8 +1069,9 @@ bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
break;
case Instruction::Mul:
pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
Builder);
case Instruction::Shl:
pushOutMulShl(Offs->getOpcode(), NewPhi, IncrementPerRound,
OffsSecondOperand, IncrementingBlock, Builder);
break;
default:
return false;
Expand Down
16 changes: 7 additions & 9 deletions llvm/test/CodeGen/Thumb2/mve-gather-increment.ll
Expand Up @@ -1410,24 +1410,22 @@ define void @shl(i32* nocapture %x, i32* noalias nocapture readonly %y, i32 %n)
; CHECK-NEXT: .LBB15_1: @ %vector.ph
; CHECK-NEXT: adr r3, .LCPI15_0
; CHECK-NEXT: vldrw.u32 q0, [r3]
; CHECK-NEXT: vmov.i32 q1, #0x4
; CHECK-NEXT: vadd.i32 q0, q0, r1
; CHECK-NEXT: dlstp.32 lr, r2
; CHECK-NEXT: .LBB15_2: @ %vector.body
; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1
; CHECK-NEXT: vshl.i32 q2, q0, #2
; CHECK-NEXT: vadd.i32 q0, q0, q1
; CHECK-NEXT: vldrw.u32 q3, [r1, q2, uxtw #2]
; CHECK-NEXT: vstrw.32 q3, [r0], #16
; CHECK-NEXT: vldrw.u32 q1, [q0, #64]!
; CHECK-NEXT: vstrw.32 q1, [r0], #16
; CHECK-NEXT: letp lr, .LBB15_2
; CHECK-NEXT: @ %bb.3: @ %for.cond.cleanup
; CHECK-NEXT: pop {r7, pc}
; CHECK-NEXT: .p2align 4
; CHECK-NEXT: @ %bb.4:
; CHECK-NEXT: .LCPI15_0:
; CHECK-NEXT: .long 0 @ 0x0
; CHECK-NEXT: .long 1 @ 0x1
; CHECK-NEXT: .long 2 @ 0x2
; CHECK-NEXT: .long 3 @ 0x3
; CHECK-NEXT: .long 4294967232 @ 0xffffffc0
; CHECK-NEXT: .long 4294967248 @ 0xffffffd0
; CHECK-NEXT: .long 4294967264 @ 0xffffffe0
; CHECK-NEXT: .long 4294967280 @ 0xfffffff0
entry:
%cmp6 = icmp sgt i32 %n, 0
br i1 %cmp6, label %vector.ph, label %for.cond.cleanup
Expand Down
16 changes: 7 additions & 9 deletions llvm/test/CodeGen/Thumb2/mve-scatter-increment.ll
Expand Up @@ -236,24 +236,22 @@ define void @shl(i32* nocapture readonly %x, i32* noalias nocapture %y, i32 %n)
; CHECK-NEXT: .LBB4_1: @ %vector.ph
; CHECK-NEXT: adr r3, .LCPI4_0
; CHECK-NEXT: vldrw.u32 q0, [r3]
; CHECK-NEXT: vmov.i32 q1, #0x4
; CHECK-NEXT: vadd.i32 q0, q0, r1
; CHECK-NEXT: dlstp.32 lr, r2
; CHECK-NEXT: .LBB4_2: @ %vector.body
; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1
; CHECK-NEXT: vshl.i32 q3, q0, #2
; CHECK-NEXT: vadd.i32 q0, q0, q1
; CHECK-NEXT: vldrw.u32 q2, [r0], #16
; CHECK-NEXT: vstrw.32 q2, [r1, q3, uxtw #2]
; CHECK-NEXT: vldrw.u32 q1, [r0], #16
; CHECK-NEXT: vstrw.32 q1, [q0, #64]!
; CHECK-NEXT: letp lr, .LBB4_2
; CHECK-NEXT: @ %bb.3: @ %for.cond.cleanup
; CHECK-NEXT: pop {r7, pc}
; CHECK-NEXT: .p2align 4
; CHECK-NEXT: @ %bb.4:
; CHECK-NEXT: .LCPI4_0:
; CHECK-NEXT: .long 0 @ 0x0
; CHECK-NEXT: .long 1 @ 0x1
; CHECK-NEXT: .long 2 @ 0x2
; CHECK-NEXT: .long 3 @ 0x3
; CHECK-NEXT: .long 4294967232 @ 0xffffffc0
; CHECK-NEXT: .long 4294967248 @ 0xffffffd0
; CHECK-NEXT: .long 4294967264 @ 0xffffffe0
; CHECK-NEXT: .long 4294967280 @ 0xfffffff0
entry:
%cmp6 = icmp sgt i32 %n, 0
br i1 %cmp6, label %vector.ph, label %for.cond.cleanup
Expand Down

0 comments on commit d36dd1f

Please sign in to comment.