Skip to content

Conversation

@SamTebbs33
Copy link
Collaborator

@SamTebbs33 SamTebbs33 commented Oct 29, 2025

This PR re-introduces the assert that the cost of a partial reduction is valid during VPExpressionRecipe creation.

This is a stacked PR:

  1. [LV] Allow partial reductions with an extended bin op #165536
  2. -> [LV] Use assertion in VPExpressionRecipe creation #165543

@llvmbot
Copy link
Member

llvmbot commented Oct 29, 2025

@llvm/pr-subscribers-vectorizers

Author: Sam Tebbs (SamTebbs33)

Changes

This PR re-introduces the assert that the cost of a partial reduction is valid during VPExpressionRecipe creation.

This is a stacked PR:

  1. [LV] Allow partial reductions with an extended bin op #165536
  2. -> [LV] Use assertion in VPExpressionRecipe creation #165543

Full diff: https://github.com/llvm/llvm-project/pull/165543.diff

1 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+38-32)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index d9ac26bba7507..e75c99c35938e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -3532,24 +3532,28 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
           auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
           TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
 
-          InstructionCost ExtRedCost;
-          InstructionCost ExtCost =
-              cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx);
-          InstructionCost RedCost = Red->computeCost(VF, Ctx);
-
           if (isa<VPPartialReductionRecipe>(Red)) {
             TargetTransformInfo::PartialReductionExtendKind ExtKind =
                 TargetTransformInfo::getPartialReductionExtendKind(ExtOpc);
             // FIXME: Move partial reduction creation, costing and clamping
             // here from LoopVectorize.cpp.
-            ExtRedCost = Ctx.TTI.getPartialReductionCost(
-                Opcode, SrcTy, nullptr, RedTy, VF, ExtKind,
-                llvm::TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind);
-          } else {
-            ExtRedCost = Ctx.TTI.getExtendedReductionCost(
-                Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
-                Red->getFastMathFlags(), CostKind);
+            InstructionCost PartialReductionCost =
+                Ctx.TTI.getPartialReductionCost(
+                    Opcode, SrcTy, nullptr, RedTy, VF, ExtKind,
+                    llvm::TargetTransformInfo::PR_None, std::nullopt,
+                    Ctx.CostKind);
+            assert(PartialReductionCost.isValid() &&
+                   "A partial reduction should have a valid cost");
+            return true;
           }
+
+          InstructionCost ExtCost =
+              cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx);
+          InstructionCost RedCost = Red->computeCost(VF, Ctx);
+
+          InstructionCost ExtRedCost = Ctx.TTI.getExtendedReductionCost(
+              Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
+              Red->getFastMathFlags(), CostKind);
           return ExtRedCost.isValid() && ExtRedCost < ExtCost + RedCost;
         },
         Range);
@@ -3595,33 +3599,35 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
           TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
           Type *SrcTy =
               Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
-          InstructionCost MulAccCost;
 
           if (IsPartialReduction) {
             Type *SrcTy2 =
                 Ext1 ? Ctx.Types.inferScalarType(Ext1->getOperand(0)) : nullptr;
             // FIXME: Move partial reduction creation, costing and clamping
             // here from LoopVectorize.cpp.
-            MulAccCost = Ctx.TTI.getPartialReductionCost(
-                Opcode, SrcTy, SrcTy2, RedTy, VF,
-                Ext0 ? TargetTransformInfo::getPartialReductionExtendKind(
-                           Ext0->getOpcode())
-                     : TargetTransformInfo::PR_None,
-                Ext1 ? TargetTransformInfo::getPartialReductionExtendKind(
-                           Ext1->getOpcode())
-                     : TargetTransformInfo::PR_None,
-                Mul->getOpcode(), CostKind);
-          } else {
-            // Only partial reductions support mixed extends at the moment.
-            if (Ext0 && Ext1 && Ext0->getOpcode() != Ext1->getOpcode())
-              return false;
-
-            bool IsZExt =
-                !Ext0 || Ext0->getOpcode() == Instruction::CastOps::ZExt;
-            auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
-            MulAccCost = Ctx.TTI.getMulAccReductionCost(IsZExt, Opcode, RedTy,
-                                                        SrcVecTy, CostKind);
+            InstructionCost PartialReductionCost =
+                Ctx.TTI.getPartialReductionCost(
+                    Opcode, SrcTy, SrcTy2, RedTy, VF,
+                    Ext0 ? TargetTransformInfo::getPartialReductionExtendKind(
+                               Ext0->getOpcode())
+                         : TargetTransformInfo::PR_None,
+                    Ext1 ? TargetTransformInfo::getPartialReductionExtendKind(
+                               Ext1->getOpcode())
+                         : TargetTransformInfo::PR_None,
+                    Mul->getOpcode(), CostKind);
+            assert(PartialReductionCost.isValid() &&
+                   "A partial reduction should have a valid cost");
+            return true;
           }
+          // Only partial reductions support mixed extends at the moment.
+          if (Ext0 && Ext1 && Ext0->getOpcode() != Ext1->getOpcode())
+            return false;
+
+          bool IsZExt =
+              !Ext0 || Ext0->getOpcode() == Instruction::CastOps::ZExt;
+          auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
+          InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
+              IsZExt, Opcode, RedTy, SrcVecTy, CostKind);
 
           InstructionCost MulCost = Mul->computeCost(VF, Ctx);
           InstructionCost RedCost = Red->computeCost(VF, Ctx);

@llvmbot
Copy link
Member

llvmbot commented Oct 29, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Sam Tebbs (SamTebbs33)

Changes

This PR re-introduces the assert that the cost of a partial reduction is valid during VPExpressionRecipe creation.

This is a stacked PR:

  1. [LV] Allow partial reductions with an extended bin op #165536
  2. -> [LV] Use assertion in VPExpressionRecipe creation #165543

Full diff: https://github.com/llvm/llvm-project/pull/165543.diff

1 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+38-32)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index d9ac26bba7507..e75c99c35938e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -3532,24 +3532,28 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
           auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
           TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
 
-          InstructionCost ExtRedCost;
-          InstructionCost ExtCost =
-              cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx);
-          InstructionCost RedCost = Red->computeCost(VF, Ctx);
-
           if (isa<VPPartialReductionRecipe>(Red)) {
             TargetTransformInfo::PartialReductionExtendKind ExtKind =
                 TargetTransformInfo::getPartialReductionExtendKind(ExtOpc);
             // FIXME: Move partial reduction creation, costing and clamping
             // here from LoopVectorize.cpp.
-            ExtRedCost = Ctx.TTI.getPartialReductionCost(
-                Opcode, SrcTy, nullptr, RedTy, VF, ExtKind,
-                llvm::TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind);
-          } else {
-            ExtRedCost = Ctx.TTI.getExtendedReductionCost(
-                Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
-                Red->getFastMathFlags(), CostKind);
+            InstructionCost PartialReductionCost =
+                Ctx.TTI.getPartialReductionCost(
+                    Opcode, SrcTy, nullptr, RedTy, VF, ExtKind,
+                    llvm::TargetTransformInfo::PR_None, std::nullopt,
+                    Ctx.CostKind);
+            assert(PartialReductionCost.isValid() &&
+                   "A partial reduction should have a valid cost");
+            return true;
           }
+
+          InstructionCost ExtCost =
+              cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx);
+          InstructionCost RedCost = Red->computeCost(VF, Ctx);
+
+          InstructionCost ExtRedCost = Ctx.TTI.getExtendedReductionCost(
+              Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
+              Red->getFastMathFlags(), CostKind);
           return ExtRedCost.isValid() && ExtRedCost < ExtCost + RedCost;
         },
         Range);
@@ -3595,33 +3599,35 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
           TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
           Type *SrcTy =
               Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
-          InstructionCost MulAccCost;
 
           if (IsPartialReduction) {
             Type *SrcTy2 =
                 Ext1 ? Ctx.Types.inferScalarType(Ext1->getOperand(0)) : nullptr;
             // FIXME: Move partial reduction creation, costing and clamping
             // here from LoopVectorize.cpp.
-            MulAccCost = Ctx.TTI.getPartialReductionCost(
-                Opcode, SrcTy, SrcTy2, RedTy, VF,
-                Ext0 ? TargetTransformInfo::getPartialReductionExtendKind(
-                           Ext0->getOpcode())
-                     : TargetTransformInfo::PR_None,
-                Ext1 ? TargetTransformInfo::getPartialReductionExtendKind(
-                           Ext1->getOpcode())
-                     : TargetTransformInfo::PR_None,
-                Mul->getOpcode(), CostKind);
-          } else {
-            // Only partial reductions support mixed extends at the moment.
-            if (Ext0 && Ext1 && Ext0->getOpcode() != Ext1->getOpcode())
-              return false;
-
-            bool IsZExt =
-                !Ext0 || Ext0->getOpcode() == Instruction::CastOps::ZExt;
-            auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
-            MulAccCost = Ctx.TTI.getMulAccReductionCost(IsZExt, Opcode, RedTy,
-                                                        SrcVecTy, CostKind);
+            InstructionCost PartialReductionCost =
+                Ctx.TTI.getPartialReductionCost(
+                    Opcode, SrcTy, SrcTy2, RedTy, VF,
+                    Ext0 ? TargetTransformInfo::getPartialReductionExtendKind(
+                               Ext0->getOpcode())
+                         : TargetTransformInfo::PR_None,
+                    Ext1 ? TargetTransformInfo::getPartialReductionExtendKind(
+                               Ext1->getOpcode())
+                         : TargetTransformInfo::PR_None,
+                    Mul->getOpcode(), CostKind);
+            assert(PartialReductionCost.isValid() &&
+                   "A partial reduction should have a valid cost");
+            return true;
           }
+          // Only partial reductions support mixed extends at the moment.
+          if (Ext0 && Ext1 && Ext0->getOpcode() != Ext1->getOpcode())
+            return false;
+
+          bool IsZExt =
+              !Ext0 || Ext0->getOpcode() == Instruction::CastOps::ZExt;
+          auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
+          InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
+              IsZExt, Opcode, RedTy, SrcVecTy, CostKind);
 
           InstructionCost MulCost = Mul->computeCost(VF, Ctx);
           InstructionCost RedCost = Red->computeCost(VF, Ctx);

@SamTebbs33
Copy link
Collaborator Author

Not needed as we'll be moving towards creating partial reductions during the VPExpressionRecipe creation process.

@SamTebbs33 SamTebbs33 closed this Nov 19, 2025
SamTebbs33 added a commit that referenced this pull request Nov 20, 2025
A pattern of the form reduce.add(ext(mul)) is valid for a partial
reduction as long as the mul and its operands fulfill the requirements
of a normal partial reduction. The mul's extend operands will be
optimised to the wider extend, and we already have oneUse checks in
place to make sure the mul and operands can be modified safely.

1. -> #165536
2. #165543
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Nov 20, 2025
…5536)

A pattern of the form reduce.add(ext(mul)) is valid for a partial
reduction as long as the mul and its operands fulfill the requirements
of a normal partial reduction. The mul's extend operands will be
optimised to the wider extend, and we already have oneUse checks in
place to make sure the mul and operands can be modified safely.

1. -> llvm/llvm-project#165536
2. llvm/llvm-project#165543
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants