Skip to content

Commit d36dd1f

Browse files
committed
[ARM] Push gather/scatter shl index updates out of loops
This teaches the MVE gather/scatter lowering pass that a Shl can be treated essentially the same as a Mul, so that a shift applied to the induction variable of a gather/scatter address can be optimized by pushing it out of the loop. https://alive2.llvm.org/ce/z/wG4VyT Differential Revision: https://reviews.llvm.org/D112920
1 parent 52615df commit d36dd1f

File tree

3 files changed

+39
-36
lines changed

3 files changed

+39
-36
lines changed

llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,10 @@ class MVEGatherScatterLowering : public FunctionPass {
149149
bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
150150
// Pushes the given add out of the loop
151151
void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
152-
// Pushes the given mul out of the loop
153-
void pushOutMul(PHINode *&Phi, Value *IncrementPerRound,
154-
Value *OffsSecondOperand, unsigned LoopIncrement,
155-
IRBuilder<> &Builder);
152+
// Pushes the given mul or shl out of the loop
153+
void pushOutMulShl(unsigned Opc, PHINode *&Phi, Value *IncrementPerRound,
154+
Value *OffsSecondOperand, unsigned LoopIncrement,
155+
IRBuilder<> &Builder);
156156
};
157157

158158
} // end anonymous namespace
@@ -342,7 +342,8 @@ Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
342342

343343
const Instruction *I = cast<Instruction>(V);
344344
if (I->getOpcode() == Instruction::Add ||
345-
I->getOpcode() == Instruction::Mul) {
345+
I->getOpcode() == Instruction::Mul ||
346+
I->getOpcode() == Instruction::Shl) {
346347
Optional<int64_t> Op0 = getIfConst(I->getOperand(0));
347348
Optional<int64_t> Op1 = getIfConst(I->getOperand(1));
348349
if (!Op0 || !Op1)
@@ -351,6 +352,8 @@ Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
351352
return Optional<int64_t>{Op0.getValue() + Op1.getValue()};
352353
if (I->getOpcode() == Instruction::Mul)
353354
return Optional<int64_t>{Op0.getValue() * Op1.getValue()};
355+
if (I->getOpcode() == Instruction::Shl)
356+
return Optional<int64_t>{Op0.getValue() << Op1.getValue()};
354357
}
355358
return Optional<int64_t>{};
356359
}
@@ -888,11 +891,11 @@ void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
888891
Phi->removeIncomingValue(StartIndex);
889892
}
890893

891-
void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
892-
Value *IncrementPerRound,
893-
Value *OffsSecondOperand,
894-
unsigned LoopIncrement,
895-
IRBuilder<> &Builder) {
894+
void MVEGatherScatterLowering::pushOutMulShl(unsigned Opcode, PHINode *&Phi,
895+
Value *IncrementPerRound,
896+
Value *OffsSecondOperand,
897+
unsigned LoopIncrement,
898+
IRBuilder<> &Builder) {
896899
LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");
897900

898901
// Create a new scalar add outside of the loop and transform it to a splat
@@ -901,12 +904,13 @@ void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
901904
Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());
902905

903906
// Create a new index
904-
Value *StartIndex = BinaryOperator::Create(
905-
Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
906-
OffsSecondOperand, "PushedOutMul", InsertionPoint);
907+
Value *StartIndex =
908+
BinaryOperator::Create((Instruction::BinaryOps)Opcode,
909+
Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
910+
OffsSecondOperand, "PushedOutMul", InsertionPoint);
907911

908912
Instruction *Product =
909-
BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
913+
BinaryOperator::Create((Instruction::BinaryOps)Opcode, IncrementPerRound,
910914
OffsSecondOperand, "Product", InsertionPoint);
911915
// Increment NewIndex by Product instead of the multiplication
912916
Instruction *NewIncrement = BinaryOperator::Create(
@@ -936,7 +940,8 @@ static bool hasAllGatScatUsers(Instruction *I) {
936940
return Gatscat;
937941
} else {
938942
unsigned OpCode = cast<Instruction>(U)->getOpcode();
939-
if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
943+
if ((OpCode == Instruction::Add || OpCode == Instruction::Mul ||
944+
OpCode == Instruction::Shl) &&
940945
hasAllGatScatUsers(cast<Instruction>(U))) {
941946
continue;
942947
}
@@ -956,7 +961,8 @@ bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
956961
return false;
957962
Instruction *Offs = cast<Instruction>(Offsets);
958963
if (Offs->getOpcode() != Instruction::Add &&
959-
Offs->getOpcode() != Instruction::Mul)
964+
Offs->getOpcode() != Instruction::Mul &&
965+
Offs->getOpcode() != Instruction::Shl)
960966
return false;
961967
Loop *L = LI->getLoopFor(BB);
962968
if (L == nullptr)
@@ -1063,8 +1069,9 @@ bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
10631069
pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
10641070
break;
10651071
case Instruction::Mul:
1066-
pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
1067-
Builder);
1072+
case Instruction::Shl:
1073+
pushOutMulShl(Offs->getOpcode(), NewPhi, IncrementPerRound,
1074+
OffsSecondOperand, IncrementingBlock, Builder);
10681075
break;
10691076
default:
10701077
return false;

llvm/test/CodeGen/Thumb2/mve-gather-increment.ll

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1410,24 +1410,22 @@ define void @shl(i32* nocapture %x, i32* noalias nocapture readonly %y, i32 %n)
14101410
; CHECK-NEXT: .LBB15_1: @ %vector.ph
14111411
; CHECK-NEXT: adr r3, .LCPI15_0
14121412
; CHECK-NEXT: vldrw.u32 q0, [r3]
1413-
; CHECK-NEXT: vmov.i32 q1, #0x4
1413+
; CHECK-NEXT: vadd.i32 q0, q0, r1
14141414
; CHECK-NEXT: dlstp.32 lr, r2
14151415
; CHECK-NEXT: .LBB15_2: @ %vector.body
14161416
; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1
1417-
; CHECK-NEXT: vshl.i32 q2, q0, #2
1418-
; CHECK-NEXT: vadd.i32 q0, q0, q1
1419-
; CHECK-NEXT: vldrw.u32 q3, [r1, q2, uxtw #2]
1420-
; CHECK-NEXT: vstrw.32 q3, [r0], #16
1417+
; CHECK-NEXT: vldrw.u32 q1, [q0, #64]!
1418+
; CHECK-NEXT: vstrw.32 q1, [r0], #16
14211419
; CHECK-NEXT: letp lr, .LBB15_2
14221420
; CHECK-NEXT: @ %bb.3: @ %for.cond.cleanup
14231421
; CHECK-NEXT: pop {r7, pc}
14241422
; CHECK-NEXT: .p2align 4
14251423
; CHECK-NEXT: @ %bb.4:
14261424
; CHECK-NEXT: .LCPI15_0:
1427-
; CHECK-NEXT: .long 0 @ 0x0
1428-
; CHECK-NEXT: .long 1 @ 0x1
1429-
; CHECK-NEXT: .long 2 @ 0x2
1430-
; CHECK-NEXT: .long 3 @ 0x3
1425+
; CHECK-NEXT: .long 4294967232 @ 0xffffffc0
1426+
; CHECK-NEXT: .long 4294967248 @ 0xffffffd0
1427+
; CHECK-NEXT: .long 4294967264 @ 0xffffffe0
1428+
; CHECK-NEXT: .long 4294967280 @ 0xfffffff0
14311429
entry:
14321430
%cmp6 = icmp sgt i32 %n, 0
14331431
br i1 %cmp6, label %vector.ph, label %for.cond.cleanup

llvm/test/CodeGen/Thumb2/mve-scatter-increment.ll

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -236,24 +236,22 @@ define void @shl(i32* nocapture readonly %x, i32* noalias nocapture %y, i32 %n)
236236
; CHECK-NEXT: .LBB4_1: @ %vector.ph
237237
; CHECK-NEXT: adr r3, .LCPI4_0
238238
; CHECK-NEXT: vldrw.u32 q0, [r3]
239-
; CHECK-NEXT: vmov.i32 q1, #0x4
239+
; CHECK-NEXT: vadd.i32 q0, q0, r1
240240
; CHECK-NEXT: dlstp.32 lr, r2
241241
; CHECK-NEXT: .LBB4_2: @ %vector.body
242242
; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1
243-
; CHECK-NEXT: vshl.i32 q3, q0, #2
244-
; CHECK-NEXT: vadd.i32 q0, q0, q1
245-
; CHECK-NEXT: vldrw.u32 q2, [r0], #16
246-
; CHECK-NEXT: vstrw.32 q2, [r1, q3, uxtw #2]
243+
; CHECK-NEXT: vldrw.u32 q1, [r0], #16
244+
; CHECK-NEXT: vstrw.32 q1, [q0, #64]!
247245
; CHECK-NEXT: letp lr, .LBB4_2
248246
; CHECK-NEXT: @ %bb.3: @ %for.cond.cleanup
249247
; CHECK-NEXT: pop {r7, pc}
250248
; CHECK-NEXT: .p2align 4
251249
; CHECK-NEXT: @ %bb.4:
252250
; CHECK-NEXT: .LCPI4_0:
253-
; CHECK-NEXT: .long 0 @ 0x0
254-
; CHECK-NEXT: .long 1 @ 0x1
255-
; CHECK-NEXT: .long 2 @ 0x2
256-
; CHECK-NEXT: .long 3 @ 0x3
251+
; CHECK-NEXT: .long 4294967232 @ 0xffffffc0
252+
; CHECK-NEXT: .long 4294967248 @ 0xffffffd0
253+
; CHECK-NEXT: .long 4294967264 @ 0xffffffe0
254+
; CHECK-NEXT: .long 4294967280 @ 0xfffffff0
257255
entry:
258256
%cmp6 = icmp sgt i32 %n, 0
259257
br i1 %cmp6, label %vector.ph, label %for.cond.cleanup

0 commit comments

Comments
 (0)