diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp index f1890e4f5fb95..fc39f4123fac4 100644 --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -139,6 +139,7 @@ class VectorCombine { bool foldShuffleOfSelects(Instruction &I); bool foldShuffleOfCastops(Instruction &I); bool foldShuffleOfShuffles(Instruction &I); + bool foldShufflesOfLengthChangingShuffles(Instruction &I); bool foldShuffleOfIntrinsics(Instruction &I); bool foldShuffleToIdentity(Instruction &I); bool foldShuffleFromReductions(Instruction &I); @@ -2877,6 +2878,171 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) { return true; } +/// Try to convert a chain of length-preserving shuffles that are fed by +/// length-changing shuffles from the same source, e.g. a chain of length 3: +/// +/// "shuffle (shuffle (shuffle x, (shuffle y, undef)), +/// (shuffle y, undef)), +/// (shuffle y, undef)" +/// +/// into a single shuffle fed by a length-changing shuffle: +/// +/// "shuffle x, (shuffle y, undef)" +/// +/// Such chains arise e.g. from folding extract/insert sequences. 
+///
+/// A leaf may reference its source through either operand, as long as the
+/// other operand is poison or the same value. Leaf masks are rewritten so that
+/// every index is relative to "shuffle y, poison" before they are merged, which
+/// is the form of the leaf that is emitted at the end.
+///
+/// \returns true if \p I was replaced by the combined shuffle.
+bool VectorCombine::foldShufflesOfLengthChangingShuffles(Instruction &I) {
+  // Only fixed-width vectors have shuffle masks we can reason about here.
+  auto *TrunkType = dyn_cast<FixedVectorType>(I.getType());
+  if (!TrunkType)
+    return false;
+
+  // Number of (trunk shuffle, leaf shuffle) links accepted so far.
+  unsigned ChainLength = 0;
+  // Mask of the combined root shuffle "shuffle trunk, leaf".
+  SmallVector<int> Mask;
+  // Mask of the combined leaf "shuffle y, poison"; indices are relative to Y.
+  SmallVector<int> YMask;
+  InstructionCost OldCost = 0;
+  InstructionCost NewCost = 0;
+  Value *Trunk = &I;
+  unsigned NumTrunkElts = TrunkType->getNumElements();
+  FixedVectorType *YType = nullptr;
+  Value *Y = nullptr;
+
+  for (;;) {
+    // Match the current trunk against (commutations of) the pattern
+    // "shuffle trunk', (shuffle y, poison)"
+    ArrayRef<int> OuterMask;
+    Value *OuterV0 = nullptr, *OuterV1 = nullptr;
+    if (ChainLength != 0 && !Trunk->hasOneUse())
+      break;
+    if (!match(Trunk, m_Shuffle(m_Value(OuterV0), m_Value(OuterV1),
+                                m_Mask(OuterMask))))
+      break;
+    if (OuterV0->getType() != TrunkType) {
+      // This shuffle is not length-preserving, so it cannot be part of the
+      // chain.
+      break;
+    }
+
+    // The pattern matchers leave their outputs untouched on failure, so start
+    // from null to keep every later read well-defined.
+    ArrayRef<int> InnerMask0, InnerMask1;
+    Value *A0 = nullptr, *A1 = nullptr, *B0 = nullptr, *B1 = nullptr;
+    bool Match0 =
+        match(OuterV0, m_Shuffle(m_Value(A0), m_Value(B0), m_Mask(InnerMask0)));
+    bool Match1 =
+        match(OuterV1, m_Shuffle(m_Value(A1), m_Value(B1), m_Mask(InnerMask1)));
+    bool Match0Leaf = Match0 && A0->getType() != I.getType();
+    bool Match1Leaf = Match1 && A1->getType() != I.getType();
+    if (Match0Leaf == Match1Leaf) {
+      // Only handle the case of exactly one leaf in each step. The "two leaves"
+      // case is handled by foldShuffleOfShuffles.
+      break;
+    }
+
+    // Canonicalize so that the leaf is always the second outer operand.
+    SmallVector<int> CommutedOuterMask;
+    if (Match0Leaf) {
+      std::swap(OuterV0, OuterV1);
+      InnerMask1 = InnerMask0;
+      A1 = A0;
+      B1 = B0;
+      llvm::append_range(CommutedOuterMask, OuterMask);
+      for (int &M : CommutedOuterMask) {
+        if (M == PoisonMaskElem)
+          continue;
+        if (M < (int)NumTrunkElts)
+          M += NumTrunkElts;
+        else
+          M -= NumTrunkElts;
+      }
+      OuterMask = CommutedOuterMask;
+    }
+    if (!OuterV1->hasOneUse())
+      break;
+
+    // Find the single non-poison source of this leaf. A leaf built only from
+    // poison has no source to share with the rest of the chain, so stop there.
+    bool SrcInA = !isa<PoisonValue>(A1);
+    bool SrcInB = !isa<PoisonValue>(B1);
+    if (SrcInA && SrcInB && A1 != B1)
+      break;
+    Value *LeafSrc = SrcInA ? A1 : (SrcInB ? B1 : nullptr);
+    if (!LeafSrc || (Y && Y != LeafSrc))
+      break;
+
+    // Rewrite the leaf mask as if the leaf were "shuffle LeafSrc, poison":
+    // lanes taken from a poison operand become poison, and lanes taken from
+    // LeafSrc in second-operand position are rebased onto the first operand.
+    // Without this, second-operand indices would select poison in the leaf
+    // that is rebuilt below.
+    int NumLeafSrcElts =
+        (int)cast<FixedVectorType>(LeafSrc->getType())->getNumElements();
+    SmallVector<int> LeafMask(InnerMask1.begin(), InnerMask1.end());
+    for (int &M : LeafMask) {
+      if (M == PoisonMaskElem)
+        continue;
+      bool FromSecond = M >= NumLeafSrcElts;
+      if (FromSecond ? !SrcInB : !SrcInA)
+        M = PoisonMaskElem;
+      else if (FromSecond)
+        M -= NumLeafSrcElts;
+    }
+
+    InstructionCost LocalOldCost =
+        TTI.getInstructionCost(cast<Instruction>(Trunk), CostKind) +
+        TTI.getInstructionCost(cast<Instruction>(OuterV1), CostKind);
+
+    // Handle the initial (start of chain) case.
+    if (!ChainLength) {
+      Y = LeafSrc;
+      YType = cast<FixedVectorType>(Y->getType());
+      Mask.assign(OuterMask.begin(), OuterMask.end());
+      YMask = std::move(LeafMask);
+      OldCost = NewCost = LocalOldCost;
+      Trunk = OuterV0;
+      ChainLength++;
+      continue;
+    }
+
+    // For the non-root case, first attempt to combine masks. Two leaves can be
+    // merged only if no lane needs two different elements of Y.
+    SmallVector<int> NewYMask(YMask);
+    bool Valid = true;
+    for (auto [CombinedM, LeafM] : llvm::zip(NewYMask, LeafMask)) {
+      if (LeafM == -1 || CombinedM == LeafM)
+        continue;
+      if (CombinedM == -1) {
+        CombinedM = LeafM;
+      } else {
+        Valid = false;
+        break;
+      }
+    }
+    if (!Valid)
+      break;
+
+    // Lanes of the accumulated mask that read the trunk are redirected through
+    // this step's outer mask; lanes that already read the leaf (or are poison)
+    // stay as they are.
+    SmallVector<int> NewMask;
+    NewMask.reserve(NumTrunkElts);
+    for (int M : Mask) {
+      if (M < 0 || M >= (int)NumTrunkElts)
+        NewMask.push_back(M);
+      else
+        NewMask.push_back(OuterMask[M]);
+    }
+
+    // Break the chain if adding this new step complicates the shuffles such
+    // that it would increase the new cost by more than the old cost of this
+    // step.
+    InstructionCost LocalNewCost = 0;
+    LocalNewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
+                                       TrunkType, YType, NewYMask, CostKind);
+    LocalNewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc,
+                                       TrunkType, TrunkType, NewMask, CostKind);
+
+    if (LocalNewCost >= NewCost && LocalOldCost < LocalNewCost - NewCost)
+      break;
+
+    LLVM_DEBUG({
+      if (ChainLength == 1) {
+        dbgs() << "Found chain of shuffles fed by length-changing shuffles: "
+               << I << '\n';
+      }
+      dbgs() << " next chain link: " << *Trunk << '\n'
+             << " old cost: " << (OldCost + LocalOldCost)
+             << " new cost: " << LocalNewCost << '\n';
+    });
+
+    Mask = NewMask;
+    YMask = NewYMask;
+    OldCost += LocalOldCost;
+    NewCost = LocalNewCost;
+    Trunk = OuterV0;
+    ChainLength++;
+  }
+  // A chain of a single link is already in the target form.
+  if (ChainLength <= 1)
+    return false;
+
+  // Y is non-null here: the first accepted link always records its source.
+  Value *Leaf = Builder.CreateShuffleVector(Y, PoisonValue::get(YType), YMask);
+  Value *Root = Builder.CreateShuffleVector(Trunk, Leaf, Mask);
+  replaceValue(I, *Root);
+  return true;
+}
+
 /// Try to convert
 /// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) { @@ -4718,6 +4884,8 @@ bool VectorCombine::run() { return true; if (foldShuffleOfShuffles(I)) return true; + if (foldShufflesOfLengthChangingShuffles(I)) + return true; if (foldShuffleOfIntrinsics(I)) return true; if (foldSelectShuffle(I)) diff --git a/llvm/test/Transforms/VectorCombine/AMDGPU/extract-insert-i8.ll b/llvm/test/Transforms/VectorCombine/AMDGPU/extract-insert-i8.ll index 7a415f4cb71d0..eaab7199a3cf3 100644 --- a/llvm/test/Transforms/VectorCombine/AMDGPU/extract-insert-i8.ll +++ b/llvm/test/Transforms/VectorCombine/AMDGPU/extract-insert-i8.ll @@ -5,39 +5,14 @@ define <32 x i8> @extract_insert_chain(<8 x i8> %in0, <8 x i8> %in1, <8 x i8> %i ; OPT-LABEL: define <32 x i8> @extract_insert_chain( ; OPT-SAME: <8 x i8> [[IN0:%.*]], <8 x i8> [[IN1:%.*]], <8 x i8> [[IN2:%.*]], <8 x i8> [[IN3:%.*]]) #[[ATTR0:[0-9]+]] { ; OPT-NEXT: [[ENTRY:.*:]] -; OPT-NEXT: [[O_1_7:%.*]] = shufflevector <8 x i8> [[IN0]], <8 x i8> [[IN1]], <32 x i32> -; OPT-NEXT: [[TMP1:%.*]] = shufflevector <8 x i8> [[IN2]], <8 x i8> poison, <32 x i32> -; OPT-NEXT: [[O_2_0:%.*]] = shufflevector <32 x i8> [[O_1_7]], <32 x i8> [[TMP1]], <32 x i32> -; OPT-NEXT: [[TMP8:%.*]] = shufflevector <8 x i8> [[IN2]], <8 x i8> poison, <32 x i32> -; OPT-NEXT: [[O_2_1:%.*]] = shufflevector <32 x i8> [[O_2_0]], <32 x i8> [[TMP8]], <32 x i32> -; OPT-NEXT: [[TMP16:%.*]] = shufflevector <8 x i8> [[IN2]], <8 x i8> poison, <32 x i32> -; OPT-NEXT: [[O_2_2:%.*]] = shufflevector <32 x i8> [[O_2_1]], <32 x i8> [[TMP16]], <32 x i32> -; OPT-NEXT: [[TMP3:%.*]] = shufflevector <8 x i8> [[IN2]], <8 x i8> poison, <32 x i32> -; OPT-NEXT: [[O_2_3:%.*]] = shufflevector <32 x i8> [[O_2_2]], <32 x i8> [[TMP3]], <32 x i32> -; OPT-NEXT: [[TMP4:%.*]] = shufflevector <8 x i8> [[IN2]], <8 x i8> poison, <32 x i32> -; OPT-NEXT: [[O_2_4:%.*]] = shufflevector <32 x i8> [[O_2_3]], <32 x i8> [[TMP4]], <32 x i32> -; OPT-NEXT: [[TMP5:%.*]] = shufflevector <8 x i8> [[IN2]], <8 x i8> 
poison, <32 x i32> -; OPT-NEXT: [[O_2_5:%.*]] = shufflevector <32 x i8> [[O_2_4]], <32 x i8> [[TMP5]], <32 x i32> -; OPT-NEXT: [[TMP6:%.*]] = shufflevector <8 x i8> [[IN2]], <8 x i8> poison, <32 x i32> -; OPT-NEXT: [[O_2_6:%.*]] = shufflevector <32 x i8> [[O_2_5]], <32 x i8> [[TMP6]], <32 x i32> -; OPT-NEXT: [[TMP7:%.*]] = shufflevector <8 x i8> [[IN2]], <8 x i8> poison, <32 x i32> -; OPT-NEXT: [[O_2_7:%.*]] = shufflevector <32 x i8> [[O_2_6]], <32 x i8> [[TMP7]], <32 x i32> -; OPT-NEXT: [[TMP2:%.*]] = shufflevector <8 x i8> [[IN3]], <8 x i8> poison, <32 x i32> -; OPT-NEXT: [[O_3_0:%.*]] = shufflevector <32 x i8> [[O_2_7]], <32 x i8> [[TMP2]], <32 x i32> -; OPT-NEXT: [[TMP9:%.*]] = shufflevector <8 x i8> [[IN3]], <8 x i8> poison, <32 x i32> -; OPT-NEXT: [[O_3_1:%.*]] = shufflevector <32 x i8> [[O_3_0]], <32 x i8> [[TMP9]], <32 x i32> -; OPT-NEXT: [[TMP10:%.*]] = shufflevector <8 x i8> [[IN3]], <8 x i8> poison, <32 x i32> -; OPT-NEXT: [[O_3_2:%.*]] = shufflevector <32 x i8> [[O_3_1]], <32 x i8> [[TMP10]], <32 x i32> -; OPT-NEXT: [[TMP11:%.*]] = shufflevector <8 x i8> [[IN3]], <8 x i8> poison, <32 x i32> -; OPT-NEXT: [[O_3_3:%.*]] = shufflevector <32 x i8> [[O_3_2]], <32 x i8> [[TMP11]], <32 x i32> -; OPT-NEXT: [[TMP12:%.*]] = shufflevector <8 x i8> [[IN3]], <8 x i8> poison, <32 x i32> -; OPT-NEXT: [[O_3_4:%.*]] = shufflevector <32 x i8> [[O_3_3]], <32 x i8> [[TMP12]], <32 x i32> -; OPT-NEXT: [[TMP13:%.*]] = shufflevector <8 x i8> [[IN3]], <8 x i8> poison, <32 x i32> -; OPT-NEXT: [[O_3_5:%.*]] = shufflevector <32 x i8> [[O_3_4]], <32 x i8> [[TMP13]], <32 x i32> -; OPT-NEXT: [[TMP14:%.*]] = shufflevector <8 x i8> [[IN3]], <8 x i8> poison, <32 x i32> -; OPT-NEXT: [[O_3_6:%.*]] = shufflevector <32 x i8> [[O_3_5]], <32 x i8> [[TMP14]], <32 x i32> -; OPT-NEXT: [[TMP15:%.*]] = shufflevector <8 x i8> [[IN3]], <8 x i8> poison, <32 x i32> -; OPT-NEXT: [[O_3_7:%.*]] = shufflevector <32 x i8> [[O_3_6]], <32 x i8> [[TMP15]], <32 x i32> +; OPT-NEXT: [[TMP0:%.*]] = shufflevector 
<8 x i8> [[IN0]], <8 x i8> poison, <32 x i32> +; OPT-NEXT: [[O_0_7:%.*]] = shufflevector <32 x i8> poison, <32 x i8> [[TMP0]], <32 x i32> +; OPT-NEXT: [[TMP1:%.*]] = shufflevector <8 x i8> [[IN1]], <8 x i8> poison, <32 x i32> +; OPT-NEXT: [[O_1_7:%.*]] = shufflevector <32 x i8> [[O_0_7]], <32 x i8> [[TMP1]], <32 x i32> +; OPT-NEXT: [[TMP2:%.*]] = shufflevector <8 x i8> [[IN2]], <8 x i8> poison, <32 x i32> +; OPT-NEXT: [[O_2_7:%.*]] = shufflevector <32 x i8> [[O_1_7]], <32 x i8> [[TMP2]], <32 x i32> +; OPT-NEXT: [[TMP3:%.*]] = shufflevector <8 x i8> [[IN3]], <8 x i8> poison, <32 x i32> +; OPT-NEXT: [[O_3_7:%.*]] = shufflevector <32 x i8> [[O_2_7]], <32 x i8> [[TMP3]], <32 x i32> ; OPT-NEXT: ret <32 x i8> [[O_3_7]] ; entry: