diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index 299f22d825277..8ad772fdbf1c5 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -4183,6 +4183,59 @@ static bool isAlreadyNarrow(VPValue *VPV) { return RepR && RepR->isSingleScalar(); } +// Convert a wide recipe defining a VPValue \p V feeding an interleave group to +// a narrow variant. +static VPValue * +narrowInterleaveGroupOp(VPValue *V, SmallPtrSetImpl &NarrowedOps) { + auto *R = V->getDefiningRecipe(); + if (!R || NarrowedOps.contains(V)) + return V; + + if (isAlreadyNarrow(V)) + return V; + + if (auto *WideMember0 = dyn_cast(R)) { + for (unsigned Idx = 0, E = WideMember0->getNumOperands(); Idx != E; ++Idx) + WideMember0->setOperand( + Idx, + narrowInterleaveGroupOp(WideMember0->getOperand(Idx), NarrowedOps)); + return V; + } + + if (auto *LoadGroup = dyn_cast(R)) { + // Narrow interleave group to wide load, as transformed VPlan will only + // process one original iteration. + auto *LI = cast(LoadGroup->getInterleaveGroup()->getInsertPos()); + auto *L = new VPWidenLoadRecipe( + *LI, LoadGroup->getAddr(), LoadGroup->getMask(), /*Consecutive=*/true, + /*Reverse=*/false, LI->getAlign(), {}, LoadGroup->getDebugLoc()); + L->insertBefore(LoadGroup); + NarrowedOps.insert(L); + return L; + } + + if (auto *RepR = dyn_cast(R)) { + assert(RepR->isSingleScalar() && + isa(RepR->getUnderlyingInstr()) && + "must be a single scalar load"); + NarrowedOps.insert(RepR); + return RepR; + } + + auto *WideLoad = cast(R); + VPValue *PtrOp = WideLoad->getAddr(); + if (auto *VecPtr = dyn_cast(PtrOp)) + PtrOp = VecPtr->getOperand(0); + // Narrow wide load to uniform scalar load, as transformed VPlan will only + // process one original iteration. + auto *N = new VPReplicateRecipe(&WideLoad->getIngredient(), {PtrOp}, + /*IsUniform*/ true, + /*Mask*/ nullptr, *WideLoad); + N->insertBefore(WideLoad); + NarrowedOps.insert(N); + return N; +} + void VPlanTransforms::narrowInterleaveGroups(VPlan &Plan, ElementCount VF, unsigned VectorRegWidth) { VPRegionBlock *VectorLoop = Plan.getVectorLoopRegion(); @@ -4284,60 +4337,10 @@ void VPlanTransforms::narrowInterleaveGroups(VPlan &Plan, ElementCount VF, // Convert InterleaveGroup \p R to a single VPWidenLoadRecipe. SmallPtrSet NarrowedOps; - auto NarrowOp = [&NarrowedOps](VPValue *V) -> VPValue * { - auto *R = V->getDefiningRecipe(); - if (!R || NarrowedOps.contains(V)) - return V; - if (auto *LoadGroup = dyn_cast(R)) { - // Narrow interleave group to wide load, as transformed VPlan will only - // process one original iteration. - auto *LI = - cast(LoadGroup->getInterleaveGroup()->getInsertPos()); - auto *L = new VPWidenLoadRecipe( - *LI, LoadGroup->getAddr(), LoadGroup->getMask(), /*Consecutive=*/true, - /*Reverse=*/false, LI->getAlign(), {}, LoadGroup->getDebugLoc()); - L->insertBefore(LoadGroup); - NarrowedOps.insert(L); - return L; - } - - if (auto *RepR = dyn_cast(R)) { - assert(RepR->isSingleScalar() && - isa(RepR->getUnderlyingInstr()) && - "must be a single scalar load"); - NarrowedOps.insert(RepR); - return RepR; - } - auto *WideLoad = cast(R); - - VPValue *PtrOp = WideLoad->getAddr(); - if (auto *VecPtr = dyn_cast(PtrOp)) - PtrOp = VecPtr->getOperand(0); - // Narrow wide load to uniform scalar load, as transformed VPlan will only - // process one original iteration. - auto *N = new VPReplicateRecipe(&WideLoad->getIngredient(), {PtrOp}, - /*IsUniform*/ true, - /*Mask*/ nullptr, *WideLoad); - N->insertBefore(WideLoad); - NarrowedOps.insert(N); - return N; - }; - // Narrow operation tree rooted at store groups. for (auto *StoreGroup : StoreGroups) { - VPValue *Res = nullptr; - VPValue *Member0 = StoreGroup->getStoredValues()[0]; - if (isAlreadyNarrow(Member0)) { - Res = Member0; - } else if (auto *WideMember0 = - dyn_cast(Member0->getDefiningRecipe())) { - for (unsigned Idx = 0, E = WideMember0->getNumOperands(); Idx != E; ++Idx) - WideMember0->setOperand(Idx, NarrowOp(WideMember0->getOperand(Idx))); - Res = WideMember0; - } else { - Res = NarrowOp(Member0); - } - + VPValue *Res = + narrowInterleaveGroupOp(StoreGroup->getStoredValues()[0], NarrowedOps); auto *SI = cast(StoreGroup->getInterleaveGroup()->getInsertPos()); auto *S = new VPWidenStoreRecipe(