diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index e0e6990c56ec7..8a67ea0045755 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -308,94 +308,11 @@ bool VPRecipeBase::isScalarCast() const { InstructionCost VPPartialReductionRecipe::computeCost(ElementCount VF, VPCostContext &Ctx) const { - std::optional Opcode; - VPValue *Op = getVecOp(); - uint64_t MulConst; - - InstructionCost CondCost = 0; - if (isConditional()) { - CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE; - auto *VecTy = Ctx.Types.inferScalarType(Op); - auto *CondTy = Ctx.Types.inferScalarType(getCondOp()); - CondCost = Ctx.TTI.getCmpSelInstrCost(Instruction::Select, VecTy, CondTy, - Pred, Ctx.CostKind); - } - - // If the partial reduction is predicated, a select will be operand 1. - // If it isn't predicated and the mul isn't operating on a constant, then it - // should have been turned into a VPExpressionRecipe. - // FIXME: Replace the entire function with this once all partial reduction - // variants are bundled into VPExpressionRecipe. - if (!match(Op, m_Mul(m_VPValue(), m_ConstantInt(MulConst)))) { - auto *PhiType = Ctx.Types.inferScalarType(getChainOp()); - auto *InputType = Ctx.Types.inferScalarType(getVecOp()); - return CondCost + Ctx.TTI.getPartialReductionCost( - getOpcode(), InputType, InputType, PhiType, VF, - TTI::PR_None, TTI::PR_None, {}, Ctx.CostKind); - } - - VPRecipeBase *OpR = Op->getDefiningRecipe(); - Type *InputTypeA = nullptr, *InputTypeB = nullptr; - TTI::PartialReductionExtendKind ExtAType = TTI::PR_None, - ExtBType = TTI::PR_None; - - auto GetExtendKind = [](VPRecipeBase *R) { - if (!R) - return TTI::PR_None; - auto *WidenCastR = dyn_cast(R); - if (!WidenCastR) - return TTI::PR_None; - if (WidenCastR->getOpcode() == Instruction::CastOps::ZExt) - return TTI::PR_ZeroExtend; - if (WidenCastR->getOpcode() == Instruction::CastOps::SExt) - return TTI::PR_SignExtend; - return TTI::PR_None; - }; - - // Pick out opcode, type/ext information and use sub side effects from a widen - // recipe. - auto HandleWiden = [&](VPWidenRecipe *Widen) { - if (match(Widen, m_Sub(m_ZeroInt(), m_VPValue(Op)))) { - Widen = dyn_cast(Op); - } - Opcode = Widen->getOpcode(); - VPRecipeBase *ExtAR = Widen->getOperand(0)->getDefiningRecipe(); - VPRecipeBase *ExtBR = Widen->getOperand(1)->getDefiningRecipe(); - InputTypeA = Ctx.Types.inferScalarType(ExtAR ? ExtAR->getOperand(0) - : Widen->getOperand(0)); - InputTypeB = Ctx.Types.inferScalarType(ExtBR ? ExtBR->getOperand(0) - : Widen->getOperand(1)); - ExtAType = GetExtendKind(ExtAR); - ExtBType = GetExtendKind(ExtBR); - - using namespace VPlanPatternMatch; - const APInt *C; - if (!ExtBR && match(Widen->getOperand(1), m_APInt(C)) && - canConstantBeExtended(C, InputTypeA, ExtAType)) { - InputTypeB = InputTypeA; - ExtBType = ExtAType; - } - }; - - if (isa(OpR)) { - InputTypeA = Ctx.Types.inferScalarType(OpR->getOperand(0)); - ExtAType = GetExtendKind(OpR); - } else if (isa(OpR)) { - if (auto RedPhiOp1R = dyn_cast_or_null(getOperand(1))) { - InputTypeA = Ctx.Types.inferScalarType(RedPhiOp1R->getOperand(0)); - ExtAType = GetExtendKind(RedPhiOp1R); - } else if (auto Widen = dyn_cast_or_null(getOperand(1))) - HandleWiden(Widen); - } else if (auto Widen = dyn_cast(OpR)) { - HandleWiden(Widen); - } else if (auto Reduction = dyn_cast(OpR)) { - return CondCost + Reduction->computeCost(VF, Ctx); - } - auto *PhiType = Ctx.Types.inferScalarType(getOperand(1)); - return CondCost + Ctx.TTI.getPartialReductionCost( - getOpcode(), InputTypeA, InputTypeB, PhiType, VF, - ExtAType, ExtBType, Opcode, Ctx.CostKind); - ; + auto *PhiType = Ctx.Types.inferScalarType(getChainOp()); + auto *InputType = Ctx.Types.inferScalarType(getVecOp()); + return Ctx.TTI.getPartialReductionCost(getOpcode(), InputType, InputType, + PhiType, VF, TTI::PR_None, + TTI::PR_None, {}, Ctx.CostKind); } void VPPartialReductionRecipe::execute(VPTransformState &State) {