Skip to content
Open
31 changes: 31 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUPromoteAlloca.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/SSAUpdater.h"

Expand Down Expand Up @@ -644,6 +645,36 @@ static Value *promoteAllocaUserToVector(Instruction *Inst, const DataLayout &DL,
auto *SubVecTy = FixedVectorType::get(VecEltTy, NumLoadedElts);
assert(DL.getTypeStoreSize(SubVecTy) == DL.getTypeStoreSize(AccessTy));

// If idx is dynamic, then sandwich load with bitcasts.
// ie. VectorTy SubVecTy AccessTy
// <64 x i8> -> <16 x i8> <8 x i16>
// <64 x i8> -> <4 x i128> -> i128 -> <8 x i16>
// Extracting subvector with dynamic index has very large expansion in
// the amdgpu backend. Limit to pow2.
FixedVectorType *VectorTy = AA.Vector.Ty;
uint64_t NumBits = DL.getTypeStoreSize(SubVecTy) * 8u;
uint64_t LoadAlign = cast<LoadInst>(Inst)->getAlign().value();
bool IsAlignedLoad = NumBits <= (LoadAlign * 8u);
unsigned TotalNumElts = VectorTy->getNumElements();
bool IsProperlyDivisible = TotalNumElts % NumLoadedElts == 0;
if (!isa<ConstantInt>(Index) &&
llvm::isPowerOf2_32(SubVecTy->getNumElements()) &&
IsProperlyDivisible && IsAlignedLoad) {
IntegerType *NewElemTy = Builder.getIntNTy(NumBits);
const unsigned NewNumElts =
DL.getTypeStoreSize(VectorTy) * 8u / NumBits;
const unsigned LShrAmt = llvm::Log2_32(SubVecTy->getNumElements());
FixedVectorType *BitCastTy =
FixedVectorType::get(NewElemTy, NewNumElts);
Value *BCVal = Builder.CreateBitCast(CurVal, BitCastTy);
Value *NewIdx = Builder.CreateLShr(
Index, ConstantInt::get(Index->getType(), LShrAmt));
Value *ExtVal = Builder.CreateExtractElement(BCVal, NewIdx);
Value *BCOut = Builder.CreateBitCast(ExtVal, AccessTy);
Inst->replaceAllUsesWith(BCOut);
return nullptr;
}

Value *SubVec = PoisonValue::get(SubVecTy);
for (unsigned K = 0; K < NumLoadedElts; ++K) {
Value *CurIdx =
Expand Down
Loading