SLPVectorizer: tidy HorizontalReduction::getReductionCost.

Rename the cost locals (VecTy -> VectorTy, SplittingRdxCost -> VectorCost,
ScalarReduxCost -> ScalarCost), scope RdxOpcode to the arithmetic-reduction
case, and pass Instruction::ICmp explicitly when costing the scalar compare
of a min/max reduction.

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 8dd318a880fcfd..bf8ef208ccf9a8 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -7058,12 +7058,10 @@ class HorizontalReduction {
   int getReductionCost(TargetTransformInfo *TTI, Value *FirstReducedVal,
                        unsigned ReduxWidth) {
     Type *ScalarTy = FirstReducedVal->getType();
-    auto *VecTy = FixedVectorType::get(ScalarTy, ReduxWidth);
+    FixedVectorType *VectorTy = FixedVectorType::get(ScalarTy, ReduxWidth);
 
     RecurKind Kind = RdxTreeInst.getKind();
-    unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind);
-    int SplittingRdxCost;
-    int ScalarReduxCost;
+    int VectorCost, ScalarCost;
     switch (Kind) {
     case RecurKind::Add:
     case RecurKind::Mul:
@@ -7071,22 +7069,24 @@ class HorizontalReduction {
     case RecurKind::And:
     case RecurKind::Xor:
     case RecurKind::FAdd:
-    case RecurKind::FMul:
-      SplittingRdxCost = TTI->getArithmeticReductionCost(
-          RdxOpcode, VecTy, /*IsPairwiseForm=*/false);
-      ScalarReduxCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy);
+    case RecurKind::FMul: {
+      unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind);
+      VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
+                                                   /*IsPairwiseForm=*/false);
+      ScalarCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy);
       break;
+    }
     case RecurKind::SMax:
     case RecurKind::SMin:
     case RecurKind::UMax:
     case RecurKind::UMin: {
-      auto *VecCondTy = cast<VectorType>(CmpInst::makeCmpResultType(VecTy));
+      auto *VecCondTy = cast<VectorType>(CmpInst::makeCmpResultType(VectorTy));
       bool IsUnsigned = Kind == RecurKind::UMax || Kind == RecurKind::UMin;
-      SplittingRdxCost =
-          TTI->getMinMaxReductionCost(VecTy, VecCondTy,
+      VectorCost =
+          TTI->getMinMaxReductionCost(VectorTy, VecCondTy,
                                       /*IsPairwiseForm=*/false, IsUnsigned);
-      ScalarReduxCost =
-          TTI->getCmpSelInstrCost(RdxOpcode, ScalarTy) +
+      ScalarCost =
+          TTI->getCmpSelInstrCost(Instruction::ICmp, ScalarTy) +
           TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy,
                                   CmpInst::makeCmpResultType(ScalarTy));
       break;
@@ -7095,12 +7095,12 @@ class HorizontalReduction {
       llvm_unreachable("Expected arithmetic or min/max reduction operation");
     }
 
-    ScalarReduxCost *= (ReduxWidth - 1);
-    LLVM_DEBUG(dbgs() << "SLP: Adding cost "
-                      << SplittingRdxCost - ScalarReduxCost
+    // Scalar cost is repeated for N-1 elements.
+    ScalarCost *= (ReduxWidth - 1);
+    LLVM_DEBUG(dbgs() << "SLP: Adding cost " << VectorCost - ScalarCost
                       << " for reduction that starts with " << *FirstReducedVal
                       << " (It is a splitting reduction)\n");
-    return SplittingRdxCost - ScalarReduxCost;
+    return VectorCost - ScalarCost;
   }
 
   /// Emit a horizontal reduction of the vectorized value.