diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 7c114b3d5195b2..acf483b4dd4124 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -1167,16 +1167,30 @@ class BoUpSLP { /// \returns the score of placing \p V1 and \p V2 in consecutive lanes. /// Also, checks if \p V1 and \p V2 are compatible with instructions in \p /// MainAltOps. - static int getShallowScore(Value *V1, Value *V2, const DataLayout &DL, - ScalarEvolution &SE, int NumLanes, - ArrayRef MainAltOps, - const TargetTransformInfo *TTI) { + int getShallowScore(Value *V1, Value *V2, Instruction *U1, Instruction *U2, + const DataLayout &DL, ScalarEvolution &SE, int NumLanes, + ArrayRef MainAltOps) { if (V1 == V2) { if (isa(V1)) { + // Retruns true if the users of V1 and V2 won't need to be extracted. + auto AllUsersAreInternal = [NumLanes, U1, U2, this](Value *V1, + Value *V2) { + // Bail out if we have too many uses to save compilation time. + static constexpr unsigned Limit = 8; + if (V1->hasNUsesOrMore(Limit) || V2->hasNUsesOrMore(Limit)) + return false; + + auto AllUsersVectorized = [U1, U2, this](Value *V) { + return llvm::all_of(V->users(), [U1, U2, this](Value *U) { + return U == U1 || U == U2 || R.getTreeEntry(U) != nullptr; + }); + }; + return AllUsersVectorized(V1) && AllUsersVectorized(V2); + }; // A broadcast of a load can be cheaper on some targets. - // TODO: For now accept a broadcast load with no other internal uses. - if (TTI->isLegalBroadcastLoad(V1->getType(), NumLanes) && - (int)V1->getNumUses() == NumLanes) + if (R.TTI->isLegalBroadcastLoad(V1->getType(), NumLanes) && + ((int)V1->getNumUses() == NumLanes || + AllUsersAreInternal(V1, V2))) return VLOperands::ScoreSplatLoads; } return VLOperands::ScoreSplat; @@ -1354,12 +1368,13 @@ class BoUpSLP { /// Look-ahead SLP: Auto-vectorization in the presence of commutative /// operations, CGO 2018 by Vasileios Porpodas, Rodrigo C. O. Rocha, /// Luís F. W. Góes - int getScoreAtLevelRec(Value *LHS, Value *RHS, int CurrLevel, int MaxLevel, + int getScoreAtLevelRec(Value *LHS, Value *RHS, Instruction *U1, + Instruction *U2, int CurrLevel, int MaxLevel, ArrayRef MainAltOps) { // Get the shallow score of V1 and V2. int ShallowScoreAtThisLevel = - getShallowScore(LHS, RHS, DL, SE, getNumLanes(), MainAltOps, R.TTI); + getShallowScore(LHS, RHS, U1, U2, DL, SE, getNumLanes(), MainAltOps); // If reached MaxLevel, // or if V1 and V2 are not instructions, @@ -1402,7 +1417,7 @@ class BoUpSLP { // Recursively calculate the cost at each level int TmpScore = getScoreAtLevelRec(I1->getOperand(OpIdx1), I2->getOperand(OpIdx2), - CurrLevel + 1, MaxLevel, None); + I1, I2, CurrLevel + 1, MaxLevel, None); // Look for the best score. if (TmpScore > VLOperands::ScoreFail && TmpScore > MaxTmpScore) { MaxTmpScore = TmpScore; @@ -1432,8 +1447,10 @@ class BoUpSLP { int getLookAheadScore(Value *LHS, Value *RHS, ArrayRef MainAltOps, int Lane, unsigned OpIdx, unsigned Idx, bool &IsUsed) { - int Score = - getScoreAtLevelRec(LHS, RHS, 1, LookAheadMaxDepth, MainAltOps); + // Keep track of the instruction stack as we recurse into the operands + // during the look-ahead score exploration. + int Score = getScoreAtLevelRec(LHS, RHS, /*U1=*/nullptr, /*U2=*/nullptr, + 1, LookAheadMaxDepth, MainAltOps); if (Score) { int SplatScore = getSplatScore(Lane, OpIdx, Idx); if (Score <= -SplatScore) { diff --git a/llvm/test/Transforms/SLPVectorizer/X86/lookahead.ll b/llvm/test/Transforms/SLPVectorizer/X86/lookahead.ll index 02e5b37c3de7cf..36b867d4a148fb 100644 --- a/llvm/test/Transforms/SLPVectorizer/X86/lookahead.ll +++ b/llvm/test/Transforms/SLPVectorizer/X86/lookahead.ll @@ -781,28 +781,50 @@ entry: ; Same as splat_loads() but the splat load has internal uses in the slp graph. define double @splat_loads_with_internal_uses(double *%array1, double *%array2, double *%ptrA, double *%ptrB) { -; CHECK-LABEL: @splat_loads_with_internal_uses( -; CHECK-NEXT: entry: -; CHECK-NEXT: [[GEP_1_0:%.*]] = getelementptr inbounds double, double* [[ARRAY1:%.*]], i64 0 -; CHECK-NEXT: [[GEP_2_0:%.*]] = getelementptr inbounds double, double* [[ARRAY2:%.*]], i64 0 -; CHECK-NEXT: [[TMP0:%.*]] = bitcast double* [[GEP_1_0]] to <2 x double>* -; CHECK-NEXT: [[TMP1:%.*]] = load <2 x double>, <2 x double>* [[TMP0]], align 8 -; CHECK-NEXT: [[TMP2:%.*]] = bitcast double* [[GEP_2_0]] to <2 x double>* -; CHECK-NEXT: [[TMP3:%.*]] = load <2 x double>, <2 x double>* [[TMP2]], align 8 -; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <2 x double> [[TMP3]], <2 x double> poison, <2 x i32> -; CHECK-NEXT: [[TMP4:%.*]] = fmul <2 x double> [[TMP1]], [[SHUFFLE]] -; CHECK-NEXT: [[TMP5:%.*]] = extractelement <2 x double> [[SHUFFLE]], i32 1 -; CHECK-NEXT: [[TMP6:%.*]] = insertelement <2 x double> poison, double [[TMP5]], i32 0 -; CHECK-NEXT: [[TMP7:%.*]] = extractelement <2 x double> [[SHUFFLE]], i32 0 -; CHECK-NEXT: [[TMP8:%.*]] = insertelement <2 x double> [[TMP6]], double [[TMP7]], i32 1 -; CHECK-NEXT: [[TMP9:%.*]] = fmul <2 x double> [[TMP1]], [[TMP8]] -; CHECK-NEXT: [[TMP10:%.*]] = fadd <2 x double> [[TMP4]], [[TMP9]] -; CHECK-NEXT: [[TMP11:%.*]] = insertelement <2 x double> [[TMP6]], double [[TMP5]], i32 1 -; CHECK-NEXT: [[TMP12:%.*]] = fsub <2 x double> [[TMP10]], [[TMP11]] -; CHECK-NEXT: [[TMP13:%.*]] = extractelement <2 x double> [[TMP12]], i32 0 -; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x double> [[TMP12]], i32 1 -; CHECK-NEXT: [[RES:%.*]] = fadd double [[TMP13]], [[TMP14]] -; CHECK-NEXT: ret double [[RES]] +; SSE-LABEL: @splat_loads_with_internal_uses( +; SSE-NEXT: entry: +; SSE-NEXT: [[GEP_1_0:%.*]] = getelementptr inbounds double, double* [[ARRAY1:%.*]], i64 0 +; SSE-NEXT: [[GEP_2_0:%.*]] = getelementptr inbounds double, double* [[ARRAY2:%.*]], i64 0 +; SSE-NEXT: [[TMP0:%.*]] = bitcast double* [[GEP_1_0]] to <2 x double>* +; SSE-NEXT: [[TMP1:%.*]] = load <2 x double>, <2 x double>* [[TMP0]], align 8 +; SSE-NEXT: [[TMP2:%.*]] = bitcast double* [[GEP_2_0]] to <2 x double>* +; SSE-NEXT: [[TMP3:%.*]] = load <2 x double>, <2 x double>* [[TMP2]], align 8 +; SSE-NEXT: [[SHUFFLE:%.*]] = shufflevector <2 x double> [[TMP3]], <2 x double> poison, <2 x i32> +; SSE-NEXT: [[TMP4:%.*]] = fmul <2 x double> [[TMP1]], [[SHUFFLE]] +; SSE-NEXT: [[TMP5:%.*]] = extractelement <2 x double> [[SHUFFLE]], i32 1 +; SSE-NEXT: [[TMP6:%.*]] = insertelement <2 x double> poison, double [[TMP5]], i32 0 +; SSE-NEXT: [[TMP7:%.*]] = extractelement <2 x double> [[SHUFFLE]], i32 0 +; SSE-NEXT: [[TMP8:%.*]] = insertelement <2 x double> [[TMP6]], double [[TMP7]], i32 1 +; SSE-NEXT: [[TMP9:%.*]] = fmul <2 x double> [[TMP1]], [[TMP8]] +; SSE-NEXT: [[TMP10:%.*]] = fadd <2 x double> [[TMP4]], [[TMP9]] +; SSE-NEXT: [[TMP11:%.*]] = insertelement <2 x double> [[TMP6]], double [[TMP5]], i32 1 +; SSE-NEXT: [[TMP12:%.*]] = fsub <2 x double> [[TMP10]], [[TMP11]] +; SSE-NEXT: [[TMP13:%.*]] = extractelement <2 x double> [[TMP12]], i32 0 +; SSE-NEXT: [[TMP14:%.*]] = extractelement <2 x double> [[TMP12]], i32 1 +; SSE-NEXT: [[RES:%.*]] = fadd double [[TMP13]], [[TMP14]] +; SSE-NEXT: ret double [[RES]] +; +; AVX-LABEL: @splat_loads_with_internal_uses( +; AVX-NEXT: entry: +; AVX-NEXT: [[GEP_1_0:%.*]] = getelementptr inbounds double, double* [[ARRAY1:%.*]], i64 0 +; AVX-NEXT: [[GEP_2_0:%.*]] = getelementptr inbounds double, double* [[ARRAY2:%.*]], i64 0 +; AVX-NEXT: [[GEP_2_1:%.*]] = getelementptr inbounds double, double* [[ARRAY2]], i64 1 +; AVX-NEXT: [[LD_2_0:%.*]] = load double, double* [[GEP_2_0]], align 8 +; AVX-NEXT: [[LD_2_1:%.*]] = load double, double* [[GEP_2_1]], align 8 +; AVX-NEXT: [[TMP0:%.*]] = bitcast double* [[GEP_1_0]] to <2 x double>* +; AVX-NEXT: [[TMP1:%.*]] = load <2 x double>, <2 x double>* [[TMP0]], align 8 +; AVX-NEXT: [[TMP2:%.*]] = insertelement <2 x double> poison, double [[LD_2_0]], i32 0 +; AVX-NEXT: [[TMP3:%.*]] = insertelement <2 x double> [[TMP2]], double [[LD_2_0]], i32 1 +; AVX-NEXT: [[TMP4:%.*]] = fmul <2 x double> [[TMP1]], [[TMP3]] +; AVX-NEXT: [[TMP5:%.*]] = insertelement <2 x double> poison, double [[LD_2_1]], i32 0 +; AVX-NEXT: [[TMP6:%.*]] = insertelement <2 x double> [[TMP5]], double [[LD_2_1]], i32 1 +; AVX-NEXT: [[TMP7:%.*]] = fmul <2 x double> [[TMP1]], [[TMP6]] +; AVX-NEXT: [[TMP8:%.*]] = fadd <2 x double> [[TMP4]], [[TMP7]] +; AVX-NEXT: [[TMP9:%.*]] = fsub <2 x double> [[TMP8]], [[TMP3]] +; AVX-NEXT: [[TMP10:%.*]] = extractelement <2 x double> [[TMP9]], i32 0 +; AVX-NEXT: [[TMP11:%.*]] = extractelement <2 x double> [[TMP9]], i32 1 +; AVX-NEXT: [[RES:%.*]] = fadd double [[TMP10]], [[TMP11]] +; AVX-NEXT: ret double [[RES]] ; entry: %gep_1_0 = getelementptr inbounds double, double* %array1, i64 0