diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index c132f0c4941c5..73313465adea9 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -847,8 +847,10 @@ class VPInstruction : public VPRecipeBase, public VPValue { const std::string Name; /// Utility method serving execute(): generates a single instance of the - /// modeled instruction. - void generateInstruction(VPTransformState &State, unsigned Part); + /// modeled instruction. \returns the value for \p Part, which may be an + /// existing value rather than a newly generated one, or nullptr when + /// nothing is emitted for \p Part (e.g. branches when \p Part != 0). + Value *generateInstruction(VPTransformState &State, unsigned Part); protected: void setUnderlyingInstr(Instruction *I) { setUnderlyingValue(I); } diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 5a4e8cca39844..26c309eed8003 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -216,41 +216,32 @@ void VPRecipeBase::moveBefore(VPBasicBlock &BB, insertBefore(BB, I); } -void VPInstruction::generateInstruction(VPTransformState &State, - unsigned Part) { +Value *VPInstruction::generateInstruction(VPTransformState &State, + unsigned Part) { IRBuilderBase &Builder = State.Builder; Builder.SetCurrentDebugLocation(DL); if (Instruction::isBinaryOp(getOpcode())) { Value *A = State.get(getOperand(0), Part); Value *B = State.get(getOperand(1), Part); - Value *V = - Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), A, B, Name); - State.set(this, V, Part); - return; + return Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), A, B, Name); } switch (getOpcode()) { case VPInstruction::Not: { Value *A = State.get(getOperand(0), Part); - Value *V = Builder.CreateNot(A, Name); - State.set(this, V, Part); - break; + return Builder.CreateNot(A, Name); } case VPInstruction::ICmpULE: { Value *IV = 
State.get(getOperand(0), Part); Value *TC = State.get(getOperand(1), Part); - Value *V = Builder.CreateICmpULE(IV, TC, Name); - State.set(this, V, Part); - break; + return Builder.CreateICmpULE(IV, TC, Name); } case Instruction::Select: { Value *Cond = State.get(getOperand(0), Part); Value *Op1 = State.get(getOperand(1), Part); Value *Op2 = State.get(getOperand(2), Part); - Value *V = Builder.CreateSelect(Cond, Op1, Op2, Name); - State.set(this, V, Part); - break; + return Builder.CreateSelect(Cond, Op1, Op2, Name); } case VPInstruction::ActiveLaneMask: { // Get first lane of vector induction variable. @@ -260,11 +251,9 @@ void VPInstruction::generateInstruction(VPTransformState &State, auto *Int1Ty = Type::getInt1Ty(Builder.getContext()); auto *PredTy = VectorType::get(Int1Ty, State.VF); - Instruction *Call = Builder.CreateIntrinsic( - Intrinsic::get_active_lane_mask, {PredTy, ScalarTC->getType()}, - {VIVElem0, ScalarTC}, nullptr, Name); - State.set(this, Call, Part); - break; + return Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask, + {PredTy, ScalarTC->getType()}, + {VIVElem0, ScalarTC}, nullptr, Name); } case VPInstruction::FirstOrderRecurrenceSplice: { // Generate code to combine the previous and current values in vector v3. @@ -282,14 +271,10 @@ void VPInstruction::generateInstruction(VPTransformState &State, // For the first part, use the recurrence phi (v1), otherwise v2. auto *V1 = State.get(getOperand(0), 0); Value *PartMinus1 = Part == 0 ? 
V1 : State.get(getOperand(1), Part - 1); - if (!PartMinus1->getType()->isVectorTy()) { - State.set(this, PartMinus1, Part); - } else { - Value *V2 = State.get(getOperand(1), Part); - State.set(this, Builder.CreateVectorSplice(PartMinus1, V2, -1, Name), - Part); - } - break; + if (!PartMinus1->getType()->isVectorTy()) + return PartMinus1; + Value *V2 = State.get(getOperand(1), Part); + return Builder.CreateVectorSplice(PartMinus1, V2, -1, Name); } case VPInstruction::CalculateTripCountMinusVF: { Value *ScalarTC = State.get(getOperand(0), {0, 0}); @@ -298,13 +283,10 @@ void VPInstruction::generateInstruction(VPTransformState &State, Value *Sub = Builder.CreateSub(ScalarTC, Step); Value *Cmp = Builder.CreateICmp(CmpInst::Predicate::ICMP_UGT, ScalarTC, Step); Value *Zero = ConstantInt::get(ScalarTC->getType(), 0); - Value *Sel = Builder.CreateSelect(Cmp, Sub, Zero); - State.set(this, Sel, Part); - break; + return Builder.CreateSelect(Cmp, Sub, Zero); } case VPInstruction::CanonicalIVIncrement: case VPInstruction::CanonicalIVIncrementNUW: { - Value *Next = nullptr; if (Part == 0) { bool IsNUW = getOpcode() == VPInstruction::CanonicalIVIncrementNUW; auto *Phi = State.get(getOperand(0), 0); @@ -312,34 +294,26 @@ void VPInstruction::generateInstruction(VPTransformState &State, // elements) times the unroll factor (num of SIMD instructions). 
Value *Step = createStepForVF(Builder, Phi->getType(), State.VF, State.UF); - Next = Builder.CreateAdd(Phi, Step, Name, IsNUW, false); - } else { - Next = State.get(this, 0); + return Builder.CreateAdd(Phi, Step, Name, IsNUW, false); } - - State.set(this, Next, Part); - break; + return State.get(this, 0); } case VPInstruction::CanonicalIVIncrementForPart: case VPInstruction::CanonicalIVIncrementForPartNUW: { bool IsNUW = getOpcode() == VPInstruction::CanonicalIVIncrementForPartNUW; auto *IV = State.get(getOperand(0), VPIteration(0, 0)); - if (Part == 0) { - State.set(this, IV, Part); - break; - } + if (Part == 0) + return IV; // The canonical IV is incremented by the vectorization factor (num of SIMD // elements) times the unroll part. Value *Step = createStepForVF(Builder, IV->getType(), State.VF, Part); - Value *Next = Builder.CreateAdd(IV, Step, Name, IsNUW, false); - State.set(this, Next, Part); - break; + return Builder.CreateAdd(IV, Step, Name, IsNUW, false); } case VPInstruction::BranchOnCond: { if (Part != 0) - break; + return nullptr; Value *Cond = State.get(getOperand(0), VPIteration(Part, 0)); VPRegionBlock *ParentRegion = getParent()->getParent(); @@ -356,11 +330,11 @@ void VPInstruction::generateInstruction(VPTransformState &State, CondBr->setSuccessor(0, nullptr); Builder.GetInsertBlock()->getTerminator()->eraseFromParent(); - break; + return CondBr; } case VPInstruction::BranchOnCount: { if (Part != 0) - break; + return nullptr; // First create the compare. 
Value *IV = State.get(getOperand(0), Part); Value *TC = State.get(getOperand(1), Part); @@ -380,7 +354,7 @@ void VPInstruction::generateInstruction(VPTransformState &State, State.CFG.VPBB2IRBB[Header]); CondBr->setSuccessor(0, nullptr); Builder.GetInsertBlock()->getTerminator()->eraseFromParent(); - break; + return CondBr; } default: llvm_unreachable("Unsupported opcode for instruction"); @@ -391,8 +365,13 @@ void VPInstruction::execute(VPTransformState &State) { assert(!State.Instance && "VPInstruction executing an Instance"); IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder); State.Builder.setFastMathFlags(FMF); - for (unsigned Part = 0; Part < State.UF; ++Part) - generateInstruction(State, Part); + for (unsigned Part = 0; Part < State.UF; ++Part) { + Value *GeneratedValue = generateInstruction(State, Part); + if (!hasResult()) + continue; + assert(GeneratedValue && "generateInstruction must produce a value"); + State.set(this, GeneratedValue, Part); + } } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)