Skip to content

Commit

Permalink
[LV] Update generateInstruction to return produced value (NFC).
Browse files Browse the repository at this point in the history
Update generateInstruction to return the produced value instead of
setting it for each opcode. This reduces the amount of duplicated code
and is a preparation for D153696.

Reviewed By: Ayal

Differential Revision: https://reviews.llvm.org/D154240
  • Loading branch information
fhahn committed Jul 5, 2023
1 parent 1039aec commit 2265bb0
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 53 deletions.
6 changes: 4 additions & 2 deletions llvm/lib/Transforms/Vectorize/VPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -847,8 +847,10 @@ class VPInstruction : public VPRecipeBase, public VPValue {
const std::string Name;

/// Utility method serving execute(): generates a single instance of the
/// modeled instruction.
void generateInstruction(VPTransformState &State, unsigned Part);
/// modeled instruction. \returns the generated value for \p Part.
/// In some cases an existing value is returned rather than a generated
/// one.
Value *generateInstruction(VPTransformState &State, unsigned Part);

protected:
void setUnderlyingInstr(Instruction *I) { setUnderlyingValue(I); }
Expand Down
81 changes: 30 additions & 51 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,41 +216,32 @@ void VPRecipeBase::moveBefore(VPBasicBlock &BB,
insertBefore(BB, I);
}

void VPInstruction::generateInstruction(VPTransformState &State,
unsigned Part) {
Value *VPInstruction::generateInstruction(VPTransformState &State,
unsigned Part) {
IRBuilderBase &Builder = State.Builder;
Builder.SetCurrentDebugLocation(DL);

if (Instruction::isBinaryOp(getOpcode())) {
Value *A = State.get(getOperand(0), Part);
Value *B = State.get(getOperand(1), Part);
Value *V =
Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), A, B, Name);
State.set(this, V, Part);
return;
return Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), A, B, Name);
}

switch (getOpcode()) {
case VPInstruction::Not: {
Value *A = State.get(getOperand(0), Part);
Value *V = Builder.CreateNot(A, Name);
State.set(this, V, Part);
break;
return Builder.CreateNot(A, Name);
}
case VPInstruction::ICmpULE: {
Value *IV = State.get(getOperand(0), Part);
Value *TC = State.get(getOperand(1), Part);
Value *V = Builder.CreateICmpULE(IV, TC, Name);
State.set(this, V, Part);
break;
return Builder.CreateICmpULE(IV, TC, Name);
}
case Instruction::Select: {
Value *Cond = State.get(getOperand(0), Part);
Value *Op1 = State.get(getOperand(1), Part);
Value *Op2 = State.get(getOperand(2), Part);
Value *V = Builder.CreateSelect(Cond, Op1, Op2, Name);
State.set(this, V, Part);
break;
return Builder.CreateSelect(Cond, Op1, Op2, Name);
}
case VPInstruction::ActiveLaneMask: {
// Get first lane of vector induction variable.
Expand All @@ -260,11 +251,9 @@ void VPInstruction::generateInstruction(VPTransformState &State,

auto *Int1Ty = Type::getInt1Ty(Builder.getContext());
auto *PredTy = VectorType::get(Int1Ty, State.VF);
Instruction *Call = Builder.CreateIntrinsic(
Intrinsic::get_active_lane_mask, {PredTy, ScalarTC->getType()},
{VIVElem0, ScalarTC}, nullptr, Name);
State.set(this, Call, Part);
break;
return Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
{PredTy, ScalarTC->getType()},
{VIVElem0, ScalarTC}, nullptr, Name);
}
case VPInstruction::FirstOrderRecurrenceSplice: {
// Generate code to combine the previous and current values in vector v3.
Expand All @@ -282,14 +271,10 @@ void VPInstruction::generateInstruction(VPTransformState &State,
// For the first part, use the recurrence phi (v1), otherwise v2.
auto *V1 = State.get(getOperand(0), 0);
Value *PartMinus1 = Part == 0 ? V1 : State.get(getOperand(1), Part - 1);
if (!PartMinus1->getType()->isVectorTy()) {
State.set(this, PartMinus1, Part);
} else {
Value *V2 = State.get(getOperand(1), Part);
State.set(this, Builder.CreateVectorSplice(PartMinus1, V2, -1, Name),
Part);
}
break;
if (!PartMinus1->getType()->isVectorTy())
return PartMinus1;
Value *V2 = State.get(getOperand(1), Part);
return Builder.CreateVectorSplice(PartMinus1, V2, -1, Name);
}
case VPInstruction::CalculateTripCountMinusVF: {
Value *ScalarTC = State.get(getOperand(0), {0, 0});
Expand All @@ -298,48 +283,37 @@ void VPInstruction::generateInstruction(VPTransformState &State,
Value *Sub = Builder.CreateSub(ScalarTC, Step);
Value *Cmp = Builder.CreateICmp(CmpInst::Predicate::ICMP_UGT, ScalarTC, Step);
Value *Zero = ConstantInt::get(ScalarTC->getType(), 0);
Value *Sel = Builder.CreateSelect(Cmp, Sub, Zero);
State.set(this, Sel, Part);
break;
return Builder.CreateSelect(Cmp, Sub, Zero);
}
case VPInstruction::CanonicalIVIncrement:
case VPInstruction::CanonicalIVIncrementNUW: {
Value *Next = nullptr;
if (Part == 0) {
bool IsNUW = getOpcode() == VPInstruction::CanonicalIVIncrementNUW;
auto *Phi = State.get(getOperand(0), 0);
// The loop step is equal to the vectorization factor (num of SIMD
// elements) times the unroll factor (num of SIMD instructions).
Value *Step =
createStepForVF(Builder, Phi->getType(), State.VF, State.UF);
Next = Builder.CreateAdd(Phi, Step, Name, IsNUW, false);
} else {
Next = State.get(this, 0);
return Builder.CreateAdd(Phi, Step, Name, IsNUW, false);
}

State.set(this, Next, Part);
break;
return State.get(this, 0);
}

case VPInstruction::CanonicalIVIncrementForPart:
case VPInstruction::CanonicalIVIncrementForPartNUW: {
bool IsNUW = getOpcode() == VPInstruction::CanonicalIVIncrementForPartNUW;
auto *IV = State.get(getOperand(0), VPIteration(0, 0));
if (Part == 0) {
State.set(this, IV, Part);
break;
}
if (Part == 0)
return IV;

// The canonical IV is incremented by the vectorization factor (num of SIMD
// elements) times the unroll part.
Value *Step = createStepForVF(Builder, IV->getType(), State.VF, Part);
Value *Next = Builder.CreateAdd(IV, Step, Name, IsNUW, false);
State.set(this, Next, Part);
break;
return Builder.CreateAdd(IV, Step, Name, IsNUW, false);
}
case VPInstruction::BranchOnCond: {
if (Part != 0)
break;
return nullptr;

Value *Cond = State.get(getOperand(0), VPIteration(Part, 0));
VPRegionBlock *ParentRegion = getParent()->getParent();
Expand All @@ -356,11 +330,11 @@ void VPInstruction::generateInstruction(VPTransformState &State,

CondBr->setSuccessor(0, nullptr);
Builder.GetInsertBlock()->getTerminator()->eraseFromParent();
break;
return CondBr;
}
case VPInstruction::BranchOnCount: {
if (Part != 0)
break;
return nullptr;
// First create the compare.
Value *IV = State.get(getOperand(0), Part);
Value *TC = State.get(getOperand(1), Part);
Expand All @@ -380,7 +354,7 @@ void VPInstruction::generateInstruction(VPTransformState &State,
State.CFG.VPBB2IRBB[Header]);
CondBr->setSuccessor(0, nullptr);
Builder.GetInsertBlock()->getTerminator()->eraseFromParent();
break;
return CondBr;
}
default:
llvm_unreachable("Unsupported opcode for instruction");
Expand All @@ -391,8 +365,13 @@ void VPInstruction::execute(VPTransformState &State) {
assert(!State.Instance && "VPInstruction executing an Instance");
IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
State.Builder.setFastMathFlags(FMF);
for (unsigned Part = 0; Part < State.UF; ++Part)
generateInstruction(State, Part);
for (unsigned Part = 0; Part < State.UF; ++Part) {
Value *GeneratedValue = generateInstruction(State, Part);
if (!hasResult())
continue;
assert(GeneratedValue && "generateInstruction must produce a value");
State.set(this, GeneratedValue, Part);
}
}

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
Expand Down

0 comments on commit 2265bb0

Please sign in to comment.