Skip to content

Commit

Permalink
[SLP] Truncate expressions to minimum required bit width
Browse files Browse the repository at this point in the history
This change attempts to produce vectorized integer expressions in bit widths
that are narrower than their scalar counterparts. The need for demotion arises
especially on architectures in which the small integer types (e.g., i8 and i16)
are not legal for scalar operations but can still be used in vectors. Like
similar work done within the loop vectorizer, we rely on InstCombine to perform
the actual type-shrinking. We use the DemandedBits analysis and
ComputeNumSignBits from ValueTracking to determine the minimum required bit
width of an expression.

Differential revision: http://reviews.llvm.org/D15815

llvm-svn: 258404
  • Loading branch information
mssimpso committed Jan 21, 2016
1 parent 61c115f commit cb17d72
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 24 deletions.
154 changes: 143 additions & 11 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Expand Up @@ -15,21 +15,22 @@
// "Loop-Aware SLP in GCC" by Ira Rosen, Dorit Nuzman, Ayal Zaks.
//
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Vectorize.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/DemandedBits.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
Expand All @@ -44,7 +45,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/Transforms/Vectorize.h"
#include <algorithm>
#include <map>
#include <memory>
Expand Down Expand Up @@ -363,11 +364,12 @@ class BoUpSLP {

BoUpSLP(Function *Func, ScalarEvolution *Se, TargetTransformInfo *Tti,
TargetLibraryInfo *TLi, AliasAnalysis *Aa, LoopInfo *Li,
DominatorTree *Dt, AssumptionCache *AC)
DominatorTree *Dt, AssumptionCache *AC, DemandedBits *DB)
: NumLoadsWantToKeepOrder(0), NumLoadsWantToChangeOrder(0), F(Func),
SE(Se), TTI(Tti), TLI(TLi), AA(Aa), LI(Li), DT(Dt),
SE(Se), TTI(Tti), TLI(TLi), AA(Aa), LI(Li), DT(Dt), AC(AC), DB(DB),
Builder(Se->getContext()) {
CodeMetrics::collectEphemeralValues(F, AC, EphValues);
MaxRequiredIntegerTy = nullptr;
}

/// \brief Vectorize the tree that starts with the elements in \p VL.
Expand Down Expand Up @@ -399,6 +401,7 @@ class BoUpSLP {
BlockScheduling *BS = Iter.second.get();
BS->clear();
}
MaxRequiredIntegerTy = nullptr;
}

/// \returns true if the memory operations A and B are consecutive.
Expand All @@ -419,6 +422,10 @@ class BoUpSLP {
/// vectorization factors.
unsigned getVectorElementSize(Value *V);

/// Compute the maximum width integer type required to represent the result
/// of a scalar expression, if such a type exists.
void computeMaxRequiredIntegerTy();

private:
struct TreeEntry;

Expand Down Expand Up @@ -924,8 +931,13 @@ class BoUpSLP {
AliasAnalysis *AA;
LoopInfo *LI;
DominatorTree *DT;
AssumptionCache *AC;
DemandedBits *DB;
/// Instruction builder to construct the vectorized tree.
IRBuilder<> Builder;

// The maximum width integer type required to represent a scalar expression.
IntegerType *MaxRequiredIntegerTy;
};

#ifndef NDEBUG
Expand Down Expand Up @@ -1481,6 +1493,15 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
ScalarTy = SI->getValueOperand()->getType();
VectorType *VecTy = VectorType::get(ScalarTy, VL.size());

// If we have computed a smaller type for the expression, update VecTy so
// that the costs will be accurate.
if (MaxRequiredIntegerTy) {
auto *IT = dyn_cast<IntegerType>(ScalarTy);
assert(IT && "Computed smaller type for non-integer value?");
if (MaxRequiredIntegerTy->getBitWidth() < IT->getBitWidth())
VecTy = VectorType::get(MaxRequiredIntegerTy, VL.size());
}

if (E->NeedToGather) {
if (allConstant(VL))
return 0;
Expand Down Expand Up @@ -1809,9 +1830,17 @@ int BoUpSLP::getTreeCost() {
if (EphValues.count(EU.User))
continue;

VectorType *VecTy = VectorType::get(EU.Scalar->getType(), BundleWidth);
ExtractCost += TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy,
EU.Lane);
// If we plan to rewrite the tree in a smaller type, we will need to sign
// extend the extracted value back to the original type. Here, we account
// for the extract and the added cost of the sign extend if needed.
auto *VecTy = VectorType::get(EU.Scalar->getType(), BundleWidth);
if (MaxRequiredIntegerTy) {
VecTy = VectorType::get(MaxRequiredIntegerTy, BundleWidth);
ExtractCost += TTI->getCastInstrCost(
Instruction::SExt, EU.Scalar->getType(), MaxRequiredIntegerTy);
}
ExtractCost +=
TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, EU.Lane);
}

Cost += getSpillCost();
Expand Down Expand Up @@ -2566,7 +2595,19 @@ Value *BoUpSLP::vectorizeTree() {
}

Builder.SetInsertPoint(&F->getEntryBlock().front());
vectorizeTree(&VectorizableTree[0]);
auto *VectorRoot = vectorizeTree(&VectorizableTree[0]);

// If the vectorized tree can be rewritten in a smaller type, we truncate the
// vectorized root. InstCombine will then rewrite the entire expression. We
// sign extend the extracted values below.
if (MaxRequiredIntegerTy) {
BasicBlock::iterator I(cast<Instruction>(VectorRoot));
Builder.SetInsertPoint(&*++I);
auto BundleWidth = VectorizableTree[0].Scalars.size();
auto *SmallerTy = VectorType::get(MaxRequiredIntegerTy, BundleWidth);
auto *Trunc = Builder.CreateTrunc(VectorRoot, SmallerTy);
VectorizableTree[0].VectorizedValue = Trunc;
}

DEBUG(dbgs() << "SLP: Extracting " << ExternalUses.size() << " values .\n");

Expand Down Expand Up @@ -2599,19 +2640,25 @@ Value *BoUpSLP::vectorizeTree() {
if (PH->getIncomingValue(i) == Scalar) {
Builder.SetInsertPoint(PH->getIncomingBlock(i)->getTerminator());
Value *Ex = Builder.CreateExtractElement(Vec, Lane);
if (MaxRequiredIntegerTy)
Ex = Builder.CreateSExt(Ex, Scalar->getType());
CSEBlocks.insert(PH->getIncomingBlock(i));
PH->setOperand(i, Ex);
}
}
} else {
Builder.SetInsertPoint(cast<Instruction>(User));
Value *Ex = Builder.CreateExtractElement(Vec, Lane);
if (MaxRequiredIntegerTy)
Ex = Builder.CreateSExt(Ex, Scalar->getType());
CSEBlocks.insert(cast<Instruction>(User)->getParent());
User->replaceUsesOfWith(Scalar, Ex);
}
} else {
Builder.SetInsertPoint(&F->getEntryBlock().front());
Value *Ex = Builder.CreateExtractElement(Vec, Lane);
if (MaxRequiredIntegerTy)
Ex = Builder.CreateSExt(Ex, Scalar->getType());
CSEBlocks.insert(&F->getEntryBlock());
User->replaceUsesOfWith(Scalar, Ex);
}
Expand Down Expand Up @@ -3180,7 +3227,7 @@ unsigned BoUpSLP::getVectorElementSize(Value *V) {
// If the current instruction is a load, update MaxWidth to reflect the
// width of the loaded value.
else if (isa<LoadInst>(I))
MaxWidth = std::max(MaxWidth, (unsigned)DL.getTypeSizeInBits(Ty));
MaxWidth = std::max<unsigned>(MaxWidth, DL.getTypeSizeInBits(Ty));

// Otherwise, we need to visit the operands of the instruction. We only
// handle the interesting cases from buildTree here. If an operand is an
Expand All @@ -3207,6 +3254,85 @@ unsigned BoUpSLP::getVectorElementSize(Value *V) {
return MaxWidth;
}

/// Decide whether the expression tree can be evaluated in an integer type
/// narrower than the type of its roots and, if so, record that type in
/// MaxRequiredIntegerTy. MaxRequiredIntegerTy is left unchanged whenever
/// narrowing is not possible.
///
/// The vectorized root is truncated to the recorded type and every extracted
/// lane is sign extended back to the original type, so the chosen width must
/// be wide enough for that round trip to be lossless in every lane.
void BoUpSLP::computeMaxRequiredIntegerTy() {

  // If there are no external uses, the expression tree must be rooted by a
  // store. We can't demote in-memory values, so there is nothing to do here.
  if (ExternalUses.empty())
    return;

  // If the expression is not rooted by a store, these roots should have
  // external uses. We will rely on InstCombine to rewrite the expression in
  // the narrower type. However, InstCombine only rewrites single-use values.
  // This means that if a tree entry other than a root is used externally, it
  // must have multiple uses and InstCombine will not rewrite it. The code
  // below ensures that only the roots are used externally.
  auto &TreeRoot = VectorizableTree[0].Scalars;
  SmallPtrSet<Value *, 16> ScalarRoots(TreeRoot.begin(), TreeRoot.end());
  for (auto &EU : ExternalUses)
    if (!ScalarRoots.erase(EU.Scalar))
      return;
  if (!ScalarRoots.empty())
    return;

  // We only compute smaller integer types.
  auto *RootIT = dyn_cast<IntegerType>(TreeRoot[0]->getType());
  if (!RootIT)
    return;
  const unsigned RootBitWidth = RootIT->getBitWidth();

  // We first check if all the bits of the roots are demanded. The truncation
  // is applied to every lane of the root, so each root has to be inspected:
  // the number of bits we must keep is the widest demanded range over all
  // lanes, not just the one of the first lane.
  unsigned DemandedBitWidth = 0;
  for (auto *Scalar : TreeRoot) {
    auto *Root = dyn_cast<Instruction>(Scalar);
    if (!Root || !isa<IntegerType>(Root->getType()) || !Root->hasOneUse())
      return;
    auto Mask = DB->getDemandedBits(Root);
    DemandedBitWidth = std::max<unsigned>(
        Mask.getBitWidth() - Mask.countLeadingZeros(), DemandedBitWidth);
  }

  // The maximum bit width required to represent all the instructions in the
  // tree without loss of precision. It would be safe to truncate the
  // expression to this width.
  unsigned MaxBitWidth = 8;

  if (DemandedBitWidth < RootBitWidth) {
    // Some high-order bits are not demanded in any root, so we can truncate
    // the roots to the demanded width.
    MaxBitWidth = DemandedBitWidth;
  } else {
    // If all the bits of a root are demanded, we can try a little harder to
    // compute a narrower type. This can happen, for example, if the roots are
    // getelementptr indices. InstCombine promotes these indices to the pointer
    // width. Thus, all their bits are technically demanded even though the
    // address computation might be vectorized in a smaller type. We look at
    // each scalar of each entry in the tree; the values in a bundle share an
    // opcode but not necessarily a value range.
    auto &DL = F->getParent()->getDataLayout();
    for (auto &Entry : VectorizableTree)
      for (auto *Scalar : Entry.Scalars) {

        // If the scalar is used more than once, InstCombine will not rewrite
        // it, so we should give up.
        if (!Scalar->hasOneUse())
          return;

        // We only compute smaller integer types. If the scalar has a
        // different type, give up.
        auto *IT = dyn_cast<IntegerType>(Scalar->getType());
        if (!IT)
          return;

        // Compute the bit width required to store the scalar. A value whose
        // top NumSignBits bits are all copies of the sign bit survives a
        // truncate followed by a sign extend only if one of those copies is
        // kept, hence the "+ 1". For example, "zext i8 255 to i32" has 24
        // sign bits; keeping only 32 - 24 = 8 bits would sign extend back
        // to -1 instead of 255.
        auto NumSignBits = ComputeNumSignBits(Scalar, DL, 0, AC, nullptr, DT);
        auto NumTypeBits = IT->getBitWidth();
        MaxBitWidth =
            std::max<unsigned>(NumTypeBits - NumSignBits + 1, MaxBitWidth);
      }
  }

  // Round up to the next power-of-two. A width of zero (no bits demanded)
  // becomes one, so the result is never zero.
  if (!isPowerOf2_64(MaxBitWidth))
    MaxBitWidth = NextPowerOf2(MaxBitWidth);

  // If the maximum bit width we compute is less than the width of the roots'
  // type, we can proceed with the narrowing. Otherwise, do nothing.
  if (MaxBitWidth < RootBitWidth)
    MaxRequiredIntegerTy = IntegerType::get(F->getContext(), MaxBitWidth);
}

/// The SLPVectorizer Pass.
struct SLPVectorizer : public FunctionPass {
typedef SmallVector<StoreInst *, 8> StoreList;
Expand All @@ -3228,6 +3354,7 @@ struct SLPVectorizer : public FunctionPass {
LoopInfo *LI;
DominatorTree *DT;
AssumptionCache *AC;
DemandedBits *DB;

bool runOnFunction(Function &F) override {
if (skipOptnoneFunction(F))
Expand All @@ -3241,6 +3368,7 @@ struct SLPVectorizer : public FunctionPass {
LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
DB = &getAnalysis<DemandedBits>();

Stores.clear();
GEPs.clear();
Expand Down Expand Up @@ -3270,7 +3398,7 @@ struct SLPVectorizer : public FunctionPass {

// Use the bottom up slp vectorizer to construct chains that start with
// store instructions.
BoUpSLP R(&F, SE, TTI, TLI, AA, LI, DT, AC);
BoUpSLP R(&F, SE, TTI, TLI, AA, LI, DT, AC, DB);

// A general note: the vectorizer must use BoUpSLP::eraseInstruction() to
// delete instructions.
Expand Down Expand Up @@ -3313,6 +3441,7 @@ struct SLPVectorizer : public FunctionPass {
AU.addRequired<TargetTransformInfoWrapperPass>();
AU.addRequired<LoopInfoWrapperPass>();
AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<DemandedBits>();
AU.addPreserved<LoopInfoWrapperPass>();
AU.addPreserved<DominatorTreeWrapperPass>();
AU.addPreserved<AAResultsWrapperPass>();
Expand Down Expand Up @@ -3417,6 +3546,7 @@ bool SLPVectorizer::vectorizeStoreChain(ArrayRef<Value *> Chain,
ArrayRef<Value *> Operands = Chain.slice(i, VF);

R.buildTree(Operands);
R.computeMaxRequiredIntegerTy();

int Cost = R.getTreeCost();

Expand Down Expand Up @@ -3616,6 +3746,7 @@ bool SLPVectorizer::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
Value *ReorderedOps[] = { Ops[1], Ops[0] };
R.buildTree(ReorderedOps, None);
}
R.computeMaxRequiredIntegerTy();
int Cost = R.getTreeCost();

if (Cost < -SLPCostThreshold) {
Expand Down Expand Up @@ -3882,6 +4013,7 @@ class HorizontalReduction {

for (; i < NumReducedVals - ReduxWidth + 1; i += ReduxWidth) {
V.buildTree(makeArrayRef(&ReducedVals[i], ReduxWidth), ReductionOps);
V.computeMaxRequiredIntegerTy();

// Estimate cost.
int Cost = V.getTreeCost() + getReductionCost(TTI, ReducedVals[i]);
Expand Down
31 changes: 18 additions & 13 deletions llvm/test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll
@@ -1,4 +1,5 @@
; RUN: opt -S -slp-vectorizer -dce -instcombine < %s | FileCheck %s
; RUN: opt -S -slp-vectorizer -dce -instcombine < %s | FileCheck %s --check-prefix=PROFITABLE
; RUN: opt -S -slp-vectorizer -slp-threshold=-12 -dce -instcombine < %s | FileCheck %s --check-prefix=UNPROFITABLE

target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128"
target triple = "aarch64--linux-gnu"
Expand All @@ -18,13 +19,13 @@ target triple = "aarch64--linux-gnu"
; return sum;
; }

; CHECK-LABEL: @gather_reduce_8x16_i32
; PROFITABLE-LABEL: @gather_reduce_8x16_i32
;
; CHECK: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16>
; CHECK: zext <8 x i16> [[L]] to <8 x i32>
; CHECK: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32>
; CHECK: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]]
; CHECK: sext i32 [[X]] to i64
; PROFITABLE: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16>
; PROFITABLE: zext <8 x i16> [[L]] to <8 x i32>
; PROFITABLE: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32>
; PROFITABLE: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]]
; PROFITABLE: sext i32 [[X]] to i64
;
define i32 @gather_reduce_8x16_i32(i16* nocapture readonly %a, i16* nocapture readonly %b, i16* nocapture readonly %g, i32 %n) {
entry:
Expand Down Expand Up @@ -137,14 +138,18 @@ for.body:
br i1 %exitcond, label %for.cond.cleanup.loopexit, label %for.body
}

; CHECK-LABEL: @gather_reduce_8x16_i64
; UNPROFITABLE-LABEL: @gather_reduce_8x16_i64
;
; CHECK-NOT: load <8 x i16>
; UNPROFITABLE: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16>
; UNPROFITABLE: zext <8 x i16> [[L]] to <8 x i32>
; UNPROFITABLE: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32>
; UNPROFITABLE: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]]
; UNPROFITABLE: sext i32 [[X]] to i64
;
; FIXME: We are currently unable to vectorize the case with i64 subtraction
; because the zero extensions are too expensive. The solution here is to
; convert the i64 subtractions to i32 subtractions during vectorization.
; This would then match the case above.
; TODO: Although we can now vectorize this case while converting the i64
; subtractions to i32, the cost model currently finds vectorization to be
; unprofitable. The cost model is penalizing the sign and zero
; extensions in the vectorized version, but they are actually free.
;
define i32 @gather_reduce_8x16_i64(i16* nocapture readonly %a, i16* nocapture readonly %b, i16* nocapture readonly %g, i32 %n) {
entry:
Expand Down

0 comments on commit cb17d72

Please sign in to comment.