From 67e726a2f73964740e319d554c354a4227f29375 Mon Sep 17 00:00:00 2001
From: Alexey Bataev
Date: Wed, 1 May 2024 07:32:33 -0400
Subject: [PATCH] [SLP]Transform stores + reverse to strided stores with
 stride -1, if profitable.

Adds transformation of consecutive vector store + reverse to strided
stores with stride -1, if it is profitable

Reviewers: RKSimon, preames

Reviewed By: RKSimon

Pull Request: https://github.com/llvm/llvm-project/pull/90464
---
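
Illustration (a sketch only; the value %v, the pointers %a0 and %a3, and the
i64 element type are placeholders, not taken from the patch): for a group of
four stores written to decreasing addresses from lanes 0..3 of %v, the cost
model added below compares the consecutive form, which needs a reverse
shuffle of %v before a unit-stride store at the lowest address %a0,

  %rev = shufflevector <4 x i64> %v, <4 x i64> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
  store <4 x i64> %rev, ptr %a0, align 8

against a single strided VP store of the unshuffled %v issued from the other
end of the group %a3, with a -8 byte stride, an all-true mask and an explicit
vector length equal to the number of stores,

  call void @llvm.experimental.vp.strided.store.v4i64.p0.i64(<4 x i64> %v, ptr align 8 %a3, i64 -8, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, i32 4)

and switches the node to the strided form only when isLegalStridedLoadStore
holds and the strided store is reported as cheaper than store + reverse.
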
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 74 +++++++++++++++++--
 .../RISCV/strided-stores-vectorized.ll        | 31 ++------
 2 files changed, 71 insertions(+), 34 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 681081de13e0..59aa2fa0554f 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -7934,6 +7934,33 @@ void BoUpSLP::transformNodes() {
       }
       break;
     }
+    case Instruction::Store: {
+      Type *ScalarTy =
+          cast<StoreInst>(E.getMainOp())->getValueOperand()->getType();
+      auto *VecTy = FixedVectorType::get(ScalarTy, E.Scalars.size());
+      Align CommonAlignment = computeCommonAlignment<StoreInst>(E.Scalars);
+      // Check if profitable to represent consecutive store + reverse as
+      // strided store with stride -1.
+      if (isReverseOrder(E.ReorderIndices) &&
+          TTI->isLegalStridedLoadStore(VecTy, CommonAlignment)) {
+        SmallVector<int> Mask;
+        inversePermutation(E.ReorderIndices, Mask);
+        auto *BaseSI = cast<StoreInst>(E.Scalars.back());
+        InstructionCost OriginalVecCost =
+            TTI->getMemoryOpCost(Instruction::Store, VecTy, BaseSI->getAlign(),
+                                 BaseSI->getPointerAddressSpace(), CostKind,
+                                 TTI::OperandValueInfo()) +
+            ::getShuffleCost(*TTI, TTI::SK_Reverse, VecTy, Mask, CostKind);
+        InstructionCost StridedCost = TTI->getStridedMemoryOpCost(
+            Instruction::Store, VecTy, BaseSI->getPointerOperand(),
+            /*VariableMask=*/false, CommonAlignment, CostKind, BaseSI);
+        if (StridedCost < OriginalVecCost)
+          // Strided store is more profitable than consecutive store +
+          // reverse - transform the node to strided store.
+          E.State = TreeEntry::StridedVectorize;
+      }
+      break;
+    }
     default:
       break;
     }
@@ -9466,11 +9493,22 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
         cast<StoreInst>(IsReorder ? VL[E->ReorderIndices.front()] : VL0);
     auto GetVectorCost = [=](InstructionCost CommonCost) {
       // We know that we can merge the stores. Calculate the cost.
-      TTI::OperandValueInfo OpInfo = getOperandInfo(E->getOperand(0));
-      return TTI->getMemoryOpCost(Instruction::Store, VecTy, BaseSI->getAlign(),
-                                  BaseSI->getPointerAddressSpace(), CostKind,
-                                  OpInfo) +
-             CommonCost;
+      InstructionCost VecStCost;
+      if (E->State == TreeEntry::StridedVectorize) {
+        Align CommonAlignment =
+            computeCommonAlignment<StoreInst>(UniqueValues.getArrayRef());
+        VecStCost = TTI->getStridedMemoryOpCost(
+            Instruction::Store, VecTy, BaseSI->getPointerOperand(),
+            /*VariableMask=*/false, CommonAlignment, CostKind);
+      } else {
+        assert(E->State == TreeEntry::Vectorize &&
+               "Expected either strided or consecutive stores.");
+        TTI::OperandValueInfo OpInfo = getOperandInfo(E->getOperand(0));
+        VecStCost = TTI->getMemoryOpCost(
+            Instruction::Store, VecTy, BaseSI->getAlign(),
+            BaseSI->getPointerAddressSpace(), CostKind, OpInfo);
+      }
+      return VecStCost + CommonCost;
     };
     SmallVector<Value *> PointerOps(VL.size());
     for (auto [I, V] : enumerate(VL)) {
@@ -12398,7 +12436,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
   bool IsReverseOrder = isReverseOrder(E->ReorderIndices);
   auto FinalShuffle = [&](Value *V, const TreeEntry *E, VectorType *VecTy) {
     ShuffleInstructionBuilder ShuffleBuilder(ScalarTy, Builder, *this);
-    if (E->getOpcode() == Instruction::Store) {
+    if (E->getOpcode() == Instruction::Store &&
+        E->State == TreeEntry::Vectorize) {
       ArrayRef<int> Mask =
           ArrayRef(reinterpret_cast<const int *>(E->ReorderIndices.begin()),
                    E->ReorderIndices.size());
@@ -12986,8 +13025,27 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         VecValue = FinalShuffle(VecValue, E, VecTy);
 
         Value *Ptr = SI->getPointerOperand();
-        StoreInst *ST =
-            Builder.CreateAlignedStore(VecValue, Ptr, SI->getAlign());
+        Instruction *ST;
+        if (E->State == TreeEntry::Vectorize) {
+          ST = Builder.CreateAlignedStore(VecValue, Ptr, SI->getAlign());
+        } else {
+          assert(E->State == TreeEntry::StridedVectorize &&
+                 "Expected either strided or consecutive stores.");
+          Align CommonAlignment = computeCommonAlignment<StoreInst>(E->Scalars);
+          Type *StrideTy = DL->getIndexType(SI->getPointerOperandType());
+          auto *Inst = Builder.CreateIntrinsic(
+              Intrinsic::experimental_vp_strided_store,
+              {VecTy, Ptr->getType(), StrideTy},
+              {VecValue, Ptr,
+               ConstantInt::get(
+                   StrideTy, -static_cast<int>(DL->getTypeAllocSize(ScalarTy))),
+               Builder.getAllOnesMask(VecTy->getElementCount()),
+               Builder.getInt32(E->Scalars.size())});
+          Inst->addParamAttr(
+              /*ArgNo=*/1,
+              Attribute::getWithAlignment(Inst->getContext(), CommonAlignment));
+          ST = Inst;
+        }
 
         Value *V = propagateMetadata(ST, E->Scalars);
 
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/strided-stores-vectorized.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/strided-stores-vectorized.ll
index 0dfa45da9d87..56e8829b0ec6 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/strided-stores-vectorized.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/strided-stores-vectorized.ll
@@ -4,33 +4,12 @@
 define void @store_reverse(ptr %p3) {
 ; CHECK-LABEL: @store_reverse(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = load i64, ptr [[P3:%.*]], align 8
-; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 8
-; CHECK-NEXT:    [[TMP1:%.*]] = load i64, ptr [[ARRAYIDX1]], align 8
-; CHECK-NEXT:    [[SHL:%.*]] = shl i64 [[TMP0]], [[TMP1]]
-; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 7
-; CHECK-NEXT:    store i64 [[SHL]], ptr [[ARRAYIDX2]], align 8
-; CHECK-NEXT:    [[ARRAYIDX3:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 1
-; CHECK-NEXT:    [[TMP2:%.*]] = load i64, ptr [[ARRAYIDX3]], align 8
-; CHECK-NEXT:    [[ARRAYIDX4:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 9
-; CHECK-NEXT:    [[TMP3:%.*]] = load i64, ptr [[ARRAYIDX4]], align 8
-; CHECK-NEXT:    [[SHL5:%.*]] = shl i64 [[TMP2]], [[TMP3]]
-; CHECK-NEXT:    [[ARRAYIDX6:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 6
-; CHECK-NEXT:    store i64 [[SHL5]], ptr [[ARRAYIDX6]], align 8
-; CHECK-NEXT:    [[ARRAYIDX7:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 2
-; CHECK-NEXT:    [[TMP4:%.*]] = load i64, ptr [[ARRAYIDX7]], align 8
-; CHECK-NEXT:    [[ARRAYIDX8:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 10
-; CHECK-NEXT:    [[TMP5:%.*]] = load i64, ptr [[ARRAYIDX8]], align 8
-; CHECK-NEXT:    [[SHL9:%.*]] = shl i64 [[TMP4]], [[TMP5]]
-; CHECK-NEXT:    [[ARRAYIDX10:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 5
-; CHECK-NEXT:    store i64 [[SHL9]], ptr [[ARRAYIDX10]], align 8
-; CHECK-NEXT:    [[ARRAYIDX11:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 3
-; CHECK-NEXT:    [[TMP6:%.*]] = load i64, ptr [[ARRAYIDX11]], align 8
-; CHECK-NEXT:    [[ARRAYIDX12:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 11
-; CHECK-NEXT:    [[TMP7:%.*]] = load i64, ptr [[ARRAYIDX12]], align 8
-; CHECK-NEXT:    [[SHL13:%.*]] = shl i64 [[TMP6]], [[TMP7]]
+; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i64, ptr [[P3:%.*]], i64 8
 ; CHECK-NEXT:    [[ARRAYIDX14:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 4
-; CHECK-NEXT:    store i64 [[SHL13]], ptr [[ARRAYIDX14]], align 8
+; CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i64>, ptr [[P3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i64>, ptr [[ARRAYIDX1]], align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = shl <4 x i64> [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    call void @llvm.experimental.vp.strided.store.v4i64.p0.i64(<4 x i64> [[TMP2]], ptr align 8 [[ARRAYIDX14]], i64 -8, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, i32 4)
 ; CHECK-NEXT:    ret void
 ;
 entry: