[LV] Optimize trip count SCEV.
To calculate the trip count we need to add 1 to the backedge-taken
count. If the backedge-taken count needs to be widened, it is better
to do the add before the widening, provided we can guarantee the add
won't overflow.

The code here is based on similar code I found in
LoopIdiomRecognize.

This is the vectorizer counterpart of the InstCombine patch D142783.
Judging by the IR diffs, it appears to catch more cases than the
InstCombine patch does.
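As an aside (an editorial illustration, not part of the commit): the reason the guard matters is that, for a 32-bit backedge-taken count BTC, zext(BTC) + 1 and zext(BTC + 1) agree exactly when BTC + 1 cannot wrap, i.e. when BTC != UINT32_MAX. A minimal self-contained C++ sketch of that equivalence (the function names are mine):

#include <cassert>
#include <cstdint>
#include <limits>

// Trip count computed as zext(BTC) + 1 (the add happens after widening).
uint64_t widenThenAdd(uint32_t BTC) { return static_cast<uint64_t>(BTC) + 1; }

// Trip count computed as zext(BTC + 1) (the add happens before widening).
uint64_t addThenWiden(uint32_t BTC) { return static_cast<uint64_t>(BTC + 1u); }

int main() {
  // Whenever BTC + 1 cannot wrap in 32 bits, the two orderings agree.
  for (uint32_t BTC : {0u, 1u, 41u, 4096u})
    assert(widenThenAdd(BTC) == addThenWiden(BTC));

  // At BTC == UINT32_MAX the narrow add wraps to 0, so folding the +1 under
  // the zero extend would be wrong without a guard.
  uint32_t Max = std::numeric_limits<uint32_t>::max();
  assert(widenThenAdd(Max) == (1ULL << 32));
  assert(addThenWiden(Max) == 0);
  return 0;
}

This is exactly the BTC != -1 condition the patch checks with isLoopEntryGuardedByCond before doing the add in the narrow type.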

Reviewed By: reames

Differential Revision: https://reviews.llvm.org/D147355
topperc committed Apr 12, 2023
1 parent 2326480 · commit 4b47d87
Showing 6 changed files with 310 additions and 309 deletions.
24 changes: 20 additions & 4 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -983,19 +983,35 @@ Value *getRuntimeVF(IRBuilderBase &B, Type *Ty, ElementCount VF) {
   return B.CreateElementCount(Ty, VF);
 }
 
-const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE) {
+const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE,
+                                Loop *OrigLoop) {
   const SCEV *BackedgeTakenCount = PSE.getBackedgeTakenCount();
   assert(!isa<SCEVCouldNotCompute>(BackedgeTakenCount) && "Invalid loop count");
 
   ScalarEvolution &SE = *PSE.getSE();
 
+  unsigned BackEdgeSize = SE.getTypeSizeInBits(BackedgeTakenCount->getType());
+  unsigned IdxSize = IdxTy->getPrimitiveSizeInBits();
+
+  // If we need to zero extend the backedge count, check if we can
+  // add one to it prior to zero extending without overflow. Provided this is
+  // safe, it allows better simplification of the +1.
+  if (OrigLoop && BackEdgeSize < IdxSize &&
+      SE.isLoopEntryGuardedByCond(
+          OrigLoop, ICmpInst::ICMP_NE, BackedgeTakenCount,
+          SE.getMinusOne(BackedgeTakenCount->getType()))) {
+    return SE.getZeroExtendExpr(
+        SE.getAddExpr(BackedgeTakenCount,
+                      SE.getOne(BackedgeTakenCount->getType())),
+        IdxTy);
+  }
+
   // The exit count might have the type of i64 while the phi is i32. This can
   // happen if we have an induction variable that is sign extended before the
   // compare. The only way that we get a backedge taken count is that the
   // induction variable was signed and as such will not overflow. In such a case
   // truncation is legal.
-  if (SE.getTypeSizeInBits(BackedgeTakenCount->getType()) >
-      IdxTy->getPrimitiveSizeInBits())
+  if (BackEdgeSize > IdxSize)
     BackedgeTakenCount = SE.getTruncateOrNoop(BackedgeTakenCount, IdxTy);
   BackedgeTakenCount = SE.getNoopOrZeroExtend(BackedgeTakenCount, IdxTy);
 
@@ -2892,7 +2908,7 @@ Value *InnerLoopVectorizer::getOrCreateTripCount(BasicBlock *InsertBlock) {
   // Find the loop boundaries.
   Type *IdxTy = Legal->getWidestInductionType();
   assert(IdxTy && "No type for induction");
-  const SCEV *ExitCount = createTripCountSCEV(IdxTy, PSE);
+  const SCEV *ExitCount = createTripCountSCEV(IdxTy, PSE, OrigLoop);
 
   const DataLayout &DL = InsertBlock->getModule()->getDataLayout();
 
3 changes: 2 additions & 1 deletion llvm/lib/Transforms/Vectorize/VPlan.h
@@ -76,7 +76,8 @@ Value *getRuntimeVF(IRBuilderBase &B, Type *Ty, ElementCount VF);
 Value *createStepForVF(IRBuilderBase &B, Type *Ty, ElementCount VF,
                        int64_t Step);
 
-const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE);
+const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE,
+                                Loop *CurLoop = nullptr);
 
 /// A range of powers-of-2 vectorization factors with fixed start and
 /// adjustable end. The range includes start and excludes end, e.g.,:
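The riscv-unroll.ll diff that follows shows the fold paying off: the loop is guarded by icmp sgt i32 %size, 0, so the backedge-taken count %size - 1 can never equal UINT32_MAX, and the trip count zext(%size - 1) + 1 simplifies to a single zext of %size. A small C++ sketch of that arithmetic (an editorial illustration; plain integers stand in for the SCEV expressions, and the function names are mine):

#include <cassert>
#include <cstdint>
#include <limits>

// Old trip-count expression: subtract 1, zero extend, then add 1 back.
uint64_t oldTripCount(int32_t Size) {
  uint32_t BTC = static_cast<uint32_t>(Size) - 1; // backedge-taken count
  return static_cast<uint64_t>(BTC) + 1;
}

// New trip-count expression: the +1 folds under the zext and cancels the -1,
// leaving a single zero extend of Size.
uint64_t newTripCount(int32_t Size) {
  return static_cast<uint64_t>(static_cast<uint32_t>(Size));
}

int main() {
  // The preheader guard in the test guarantees Size > 0, so Size - 1 can
  // never be UINT32_MAX and the two expressions always agree.
  for (int32_t Size : {1, 2, 8, 1000, std::numeric_limits<int32_t>::max()}) {
    assert(Size > 0 && "mirrors the icmp sgt guard");
    assert(oldTripCount(Size) == newTripCount(Size));
  }
  return 0;
}

The same simplification drops two preheader instructions in both the LMUL1 and LMUL2 check lines below.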
84 changes: 40 additions & 44 deletions llvm/test/Transforms/LoopVectorize/RISCV/riscv-unroll.ll
@@ -11,44 +11,42 @@ define ptr @array_add(ptr noalias nocapture readonly %a, ptr noalias nocapture r
 ; LMUL1-NEXT: [[CMP10:%.*]] = icmp sgt i32 [[SIZE:%.*]], 0
 ; LMUL1-NEXT: br i1 [[CMP10]], label [[FOR_BODY_PREHEADER:%.*]], label [[FOR_END:%.*]]
 ; LMUL1: for.body.preheader:
-; LMUL1-NEXT: [[TMP0:%.*]] = add i32 [[SIZE]], -1
-; LMUL1-NEXT: [[TMP1:%.*]] = zext i32 [[TMP0]] to i64
-; LMUL1-NEXT: [[TMP2:%.*]] = add nuw nsw i64 [[TMP1]], 1
-; LMUL1-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP2]], 8
+; LMUL1-NEXT: [[TMP0:%.*]] = zext i32 [[SIZE]] to i64
+; LMUL1-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], 8
 ; LMUL1-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; LMUL1: vector.ph:
-; LMUL1-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[TMP2]], 8
-; LMUL1-NEXT: [[N_VEC:%.*]] = sub i64 [[TMP2]], [[N_MOD_VF]]
+; LMUL1-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[TMP0]], 8
+; LMUL1-NEXT: [[N_VEC:%.*]] = sub i64 [[TMP0]], [[N_MOD_VF]]
 ; LMUL1-NEXT: br label [[VECTOR_BODY:%.*]]
 ; LMUL1: vector.body:
 ; LMUL1-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; LMUL1-NEXT: [[TMP3:%.*]] = add i64 [[INDEX]], 0
-; LMUL1-NEXT: [[TMP4:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[TMP3]]
+; LMUL1-NEXT: [[TMP1:%.*]] = add i64 [[INDEX]], 0
+; LMUL1-NEXT: [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[TMP1]]
+; LMUL1-NEXT: [[TMP3:%.*]] = getelementptr inbounds i32, ptr [[TMP2]], i32 0
+; LMUL1-NEXT: [[WIDE_LOAD:%.*]] = load <8 x i32>, ptr [[TMP3]], align 4
+; LMUL1-NEXT: [[TMP4:%.*]] = getelementptr inbounds i32, ptr [[B:%.*]], i64 [[TMP1]]
 ; LMUL1-NEXT: [[TMP5:%.*]] = getelementptr inbounds i32, ptr [[TMP4]], i32 0
-; LMUL1-NEXT: [[WIDE_LOAD:%.*]] = load <8 x i32>, ptr [[TMP5]], align 4
-; LMUL1-NEXT: [[TMP6:%.*]] = getelementptr inbounds i32, ptr [[B:%.*]], i64 [[TMP3]]
-; LMUL1-NEXT: [[TMP7:%.*]] = getelementptr inbounds i32, ptr [[TMP6]], i32 0
-; LMUL1-NEXT: [[WIDE_LOAD1:%.*]] = load <8 x i32>, ptr [[TMP7]], align 4
-; LMUL1-NEXT: [[TMP8:%.*]] = add nsw <8 x i32> [[WIDE_LOAD1]], [[WIDE_LOAD]]
-; LMUL1-NEXT: [[TMP9:%.*]] = getelementptr inbounds i32, ptr [[C:%.*]], i64 [[TMP3]]
-; LMUL1-NEXT: [[TMP10:%.*]] = getelementptr inbounds i32, ptr [[TMP9]], i32 0
-; LMUL1-NEXT: store <8 x i32> [[TMP8]], ptr [[TMP10]], align 4
+; LMUL1-NEXT: [[WIDE_LOAD1:%.*]] = load <8 x i32>, ptr [[TMP5]], align 4
+; LMUL1-NEXT: [[TMP6:%.*]] = add nsw <8 x i32> [[WIDE_LOAD1]], [[WIDE_LOAD]]
+; LMUL1-NEXT: [[TMP7:%.*]] = getelementptr inbounds i32, ptr [[C:%.*]], i64 [[TMP1]]
+; LMUL1-NEXT: [[TMP8:%.*]] = getelementptr inbounds i32, ptr [[TMP7]], i32 0
+; LMUL1-NEXT: store <8 x i32> [[TMP6]], ptr [[TMP8]], align 4
 ; LMUL1-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
-; LMUL1-NEXT: [[TMP11:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; LMUL1-NEXT: br i1 [[TMP11]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; LMUL1-NEXT: [[TMP9:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; LMUL1-NEXT: br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; LMUL1: middle.block:
-; LMUL1-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP2]], [[N_VEC]]
+; LMUL1-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP0]], [[N_VEC]]
 ; LMUL1-NEXT: br i1 [[CMP_N]], label [[FOR_END_LOOPEXIT:%.*]], label [[SCALAR_PH]]
 ; LMUL1: scalar.ph:
 ; LMUL1-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[FOR_BODY_PREHEADER]] ]
 ; LMUL1-NEXT: br label [[FOR_BODY:%.*]]
 ; LMUL1: for.body:
 ; LMUL1-NEXT: [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
 ; LMUL1-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[A]], i64 [[INDVARS_IV]]
-; LMUL1-NEXT: [[TMP12:%.*]] = load i32, ptr [[ARRAYIDX]], align 4
+; LMUL1-NEXT: [[TMP10:%.*]] = load i32, ptr [[ARRAYIDX]], align 4
 ; LMUL1-NEXT: [[ARRAYIDX2:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[INDVARS_IV]]
-; LMUL1-NEXT: [[TMP13:%.*]] = load i32, ptr [[ARRAYIDX2]], align 4
-; LMUL1-NEXT: [[ADD:%.*]] = add nsw i32 [[TMP13]], [[TMP12]]
+; LMUL1-NEXT: [[TMP11:%.*]] = load i32, ptr [[ARRAYIDX2]], align 4
+; LMUL1-NEXT: [[ADD:%.*]] = add nsw i32 [[TMP11]], [[TMP10]]
 ; LMUL1-NEXT: [[ARRAYIDX4:%.*]] = getelementptr inbounds i32, ptr [[C]], i64 [[INDVARS_IV]]
 ; LMUL1-NEXT: store i32 [[ADD]], ptr [[ARRAYIDX4]], align 4
 ; LMUL1-NEXT: [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 1
@@ -65,44 +63,42 @@ define ptr @array_add(ptr noalias nocapture readonly %a, ptr noalias nocapture r
 ; LMUL2-NEXT: [[CMP10:%.*]] = icmp sgt i32 [[SIZE:%.*]], 0
 ; LMUL2-NEXT: br i1 [[CMP10]], label [[FOR_BODY_PREHEADER:%.*]], label [[FOR_END:%.*]]
 ; LMUL2: for.body.preheader:
-; LMUL2-NEXT: [[TMP0:%.*]] = add i32 [[SIZE]], -1
-; LMUL2-NEXT: [[TMP1:%.*]] = zext i32 [[TMP0]] to i64
-; LMUL2-NEXT: [[TMP2:%.*]] = add nuw nsw i64 [[TMP1]], 1
-; LMUL2-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP2]], 8
+; LMUL2-NEXT: [[TMP0:%.*]] = zext i32 [[SIZE]] to i64
+; LMUL2-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], 8
 ; LMUL2-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; LMUL2: vector.ph:
-; LMUL2-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[TMP2]], 8
-; LMUL2-NEXT: [[N_VEC:%.*]] = sub i64 [[TMP2]], [[N_MOD_VF]]
+; LMUL2-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[TMP0]], 8
+; LMUL2-NEXT: [[N_VEC:%.*]] = sub i64 [[TMP0]], [[N_MOD_VF]]
 ; LMUL2-NEXT: br label [[VECTOR_BODY:%.*]]
 ; LMUL2: vector.body:
 ; LMUL2-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; LMUL2-NEXT: [[TMP3:%.*]] = add i64 [[INDEX]], 0
-; LMUL2-NEXT: [[TMP4:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[TMP3]]
+; LMUL2-NEXT: [[TMP1:%.*]] = add i64 [[INDEX]], 0
+; LMUL2-NEXT: [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[TMP1]]
+; LMUL2-NEXT: [[TMP3:%.*]] = getelementptr inbounds i32, ptr [[TMP2]], i32 0
+; LMUL2-NEXT: [[WIDE_LOAD:%.*]] = load <8 x i32>, ptr [[TMP3]], align 4
+; LMUL2-NEXT: [[TMP4:%.*]] = getelementptr inbounds i32, ptr [[B:%.*]], i64 [[TMP1]]
 ; LMUL2-NEXT: [[TMP5:%.*]] = getelementptr inbounds i32, ptr [[TMP4]], i32 0
-; LMUL2-NEXT: [[WIDE_LOAD:%.*]] = load <8 x i32>, ptr [[TMP5]], align 4
-; LMUL2-NEXT: [[TMP6:%.*]] = getelementptr inbounds i32, ptr [[B:%.*]], i64 [[TMP3]]
-; LMUL2-NEXT: [[TMP7:%.*]] = getelementptr inbounds i32, ptr [[TMP6]], i32 0
-; LMUL2-NEXT: [[WIDE_LOAD1:%.*]] = load <8 x i32>, ptr [[TMP7]], align 4
-; LMUL2-NEXT: [[TMP8:%.*]] = add nsw <8 x i32> [[WIDE_LOAD1]], [[WIDE_LOAD]]
-; LMUL2-NEXT: [[TMP9:%.*]] = getelementptr inbounds i32, ptr [[C:%.*]], i64 [[TMP3]]
-; LMUL2-NEXT: [[TMP10:%.*]] = getelementptr inbounds i32, ptr [[TMP9]], i32 0
-; LMUL2-NEXT: store <8 x i32> [[TMP8]], ptr [[TMP10]], align 4
+; LMUL2-NEXT: [[WIDE_LOAD1:%.*]] = load <8 x i32>, ptr [[TMP5]], align 4
+; LMUL2-NEXT: [[TMP6:%.*]] = add nsw <8 x i32> [[WIDE_LOAD1]], [[WIDE_LOAD]]
+; LMUL2-NEXT: [[TMP7:%.*]] = getelementptr inbounds i32, ptr [[C:%.*]], i64 [[TMP1]]
+; LMUL2-NEXT: [[TMP8:%.*]] = getelementptr inbounds i32, ptr [[TMP7]], i32 0
+; LMUL2-NEXT: store <8 x i32> [[TMP6]], ptr [[TMP8]], align 4
 ; LMUL2-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
-; LMUL2-NEXT: [[TMP11:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; LMUL2-NEXT: br i1 [[TMP11]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; LMUL2-NEXT: [[TMP9:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; LMUL2-NEXT: br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; LMUL2: middle.block:
-; LMUL2-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP2]], [[N_VEC]]
+; LMUL2-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP0]], [[N_VEC]]
 ; LMUL2-NEXT: br i1 [[CMP_N]], label [[FOR_END_LOOPEXIT:%.*]], label [[SCALAR_PH]]
 ; LMUL2: scalar.ph:
 ; LMUL2-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[FOR_BODY_PREHEADER]] ]
 ; LMUL2-NEXT: br label [[FOR_BODY:%.*]]
 ; LMUL2: for.body:
 ; LMUL2-NEXT: [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
 ; LMUL2-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[A]], i64 [[INDVARS_IV]]
-; LMUL2-NEXT: [[TMP12:%.*]] = load i32, ptr [[ARRAYIDX]], align 4
+; LMUL2-NEXT: [[TMP10:%.*]] = load i32, ptr [[ARRAYIDX]], align 4
 ; LMUL2-NEXT: [[ARRAYIDX2:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[INDVARS_IV]]
-; LMUL2-NEXT: [[TMP13:%.*]] = load i32, ptr [[ARRAYIDX2]], align 4
-; LMUL2-NEXT: [[ADD:%.*]] = add nsw i32 [[TMP13]], [[TMP12]]
+; LMUL2-NEXT: [[TMP11:%.*]] = load i32, ptr [[ARRAYIDX2]], align 4
+; LMUL2-NEXT: [[ADD:%.*]] = add nsw i32 [[TMP11]], [[TMP10]]
 ; LMUL2-NEXT: [[ARRAYIDX4:%.*]] = getelementptr inbounds i32, ptr [[C]], i64 [[INDVARS_IV]]
 ; LMUL2-NEXT: store i32 [[ADD]], ptr [[ARRAYIDX4]], align 4
 ; LMUL2-NEXT: [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 1
(The diffs for the remaining 3 changed files are not shown here.)
