@@ -149,10 +149,10 @@ class MVEGatherScatterLowering : public FunctionPass {
149
149
bool optimiseOffsets (Value *Offsets, BasicBlock *BB, LoopInfo *LI);
150
150
// Pushes the given add out of the loop
151
151
void pushOutAdd (PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
152
- // Pushes the given mul out of the loop
153
- void pushOutMul ( PHINode *&Phi, Value *IncrementPerRound,
154
- Value *OffsSecondOperand, unsigned LoopIncrement,
155
- IRBuilder<> &Builder);
152
+ // Pushes the given mul or shl out of the loop
153
+ void pushOutMulShl ( unsigned Opc, PHINode *&Phi, Value *IncrementPerRound,
154
+ Value *OffsSecondOperand, unsigned LoopIncrement,
155
+ IRBuilder<> &Builder);
156
156
};
157
157
158
158
} // end anonymous namespace
@@ -342,7 +342,8 @@ Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
342
342
343
343
const Instruction *I = cast<Instruction>(V);
344
344
if (I->getOpcode () == Instruction::Add ||
345
- I->getOpcode () == Instruction::Mul) {
345
+ I->getOpcode () == Instruction::Mul ||
346
+ I->getOpcode () == Instruction::Shl) {
346
347
Optional<int64_t > Op0 = getIfConst (I->getOperand (0 ));
347
348
Optional<int64_t > Op1 = getIfConst (I->getOperand (1 ));
348
349
if (!Op0 || !Op1)
@@ -351,6 +352,8 @@ Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
351
352
return Optional<int64_t >{Op0.getValue () + Op1.getValue ()};
352
353
if (I->getOpcode () == Instruction::Mul)
353
354
return Optional<int64_t >{Op0.getValue () * Op1.getValue ()};
355
+ if (I->getOpcode () == Instruction::Shl)
356
+ return Optional<int64_t >{Op0.getValue () << Op1.getValue ()};
354
357
}
355
358
return Optional<int64_t >{};
356
359
}
@@ -888,11 +891,11 @@ void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
888
891
Phi->removeIncomingValue (StartIndex);
889
892
}
890
893
891
- void MVEGatherScatterLowering::pushOutMul ( PHINode *&Phi,
892
- Value *IncrementPerRound,
893
- Value *OffsSecondOperand,
894
- unsigned LoopIncrement,
895
- IRBuilder<> &Builder) {
894
+ void MVEGatherScatterLowering::pushOutMulShl ( unsigned Opcode, PHINode *&Phi,
895
+ Value *IncrementPerRound,
896
+ Value *OffsSecondOperand,
897
+ unsigned LoopIncrement,
898
+ IRBuilder<> &Builder) {
896
899
LLVM_DEBUG (dbgs () << " masked gathers/scatters: optimising mul instruction\n " );
897
900
898
901
// Create a new scalar add outside of the loop and transform it to a splat
@@ -901,12 +904,13 @@ void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
901
904
Phi->getIncomingBlock (LoopIncrement == 1 ? 0 : 1 )->back ());
902
905
903
906
// Create a new index
904
- Value *StartIndex = BinaryOperator::Create (
905
- Instruction::Mul, Phi->getIncomingValue (LoopIncrement == 1 ? 0 : 1 ),
906
- OffsSecondOperand, " PushedOutMul" , InsertionPoint);
907
+ Value *StartIndex =
908
+ BinaryOperator::Create ((Instruction::BinaryOps)Opcode,
909
+ Phi->getIncomingValue (LoopIncrement == 1 ? 0 : 1 ),
910
+ OffsSecondOperand, " PushedOutMul" , InsertionPoint);
907
911
908
912
Instruction *Product =
909
- BinaryOperator::Create (Instruction::Mul , IncrementPerRound,
913
+ BinaryOperator::Create (( Instruction::BinaryOps)Opcode , IncrementPerRound,
910
914
OffsSecondOperand, " Product" , InsertionPoint);
911
915
// Increment NewIndex by Product instead of the multiplication
912
916
Instruction *NewIncrement = BinaryOperator::Create (
@@ -936,7 +940,8 @@ static bool hasAllGatScatUsers(Instruction *I) {
936
940
return Gatscat;
937
941
} else {
938
942
unsigned OpCode = cast<Instruction>(U)->getOpcode ();
939
- if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
943
+ if ((OpCode == Instruction::Add || OpCode == Instruction::Mul ||
944
+ OpCode == Instruction::Shl) &&
940
945
hasAllGatScatUsers (cast<Instruction>(U))) {
941
946
continue ;
942
947
}
@@ -956,7 +961,8 @@ bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
956
961
return false ;
957
962
Instruction *Offs = cast<Instruction>(Offsets);
958
963
if (Offs->getOpcode () != Instruction::Add &&
959
- Offs->getOpcode () != Instruction::Mul)
964
+ Offs->getOpcode () != Instruction::Mul &&
965
+ Offs->getOpcode () != Instruction::Shl)
960
966
return false ;
961
967
Loop *L = LI->getLoopFor (BB);
962
968
if (L == nullptr )
@@ -1063,8 +1069,9 @@ bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
1063
1069
pushOutAdd (NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1 );
1064
1070
break ;
1065
1071
case Instruction::Mul:
1066
- pushOutMul (NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
1067
- Builder);
1072
+ case Instruction::Shl:
1073
+ pushOutMulShl (Offs->getOpcode (), NewPhi, IncrementPerRound,
1074
+ OffsSecondOperand, IncrementingBlock, Builder);
1068
1075
break ;
1069
1076
default :
1070
1077
return false ;
0 commit comments