diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index e0a779467b1fa..bde65717ac1d4 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -7888,7 +7888,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
           unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
           unsigned SrcBWSz = DL->getTypeSizeInBits(UserScalarTy);
           unsigned VecOpcode;
-          auto *SrcVecTy =
+          auto *UserVecTy =
               FixedVectorType::get(UserScalarTy, E->getVectorFactor());
           if (BWSz > SrcBWSz)
             VecOpcode = Instruction::Trunc;
@@ -7896,11 +7896,10 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
             VecOpcode =
                 It->second.second ? Instruction::SExt : Instruction::ZExt;
           TTI::CastContextHint CCH = GetCastContextHint(VL0);
-          VecCost += TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH,
+          VecCost += TTI->getCastInstrCost(VecOpcode, UserVecTy, VecTy, CCH,
                                            CostKind);
-          ScalarCost +=
-              Sz * TTI->getCastInstrCost(VecOpcode, ScalarTy, UserScalarTy,
-                                         CCH, CostKind);
+          ScalarCost += Sz * TTI->getCastInstrCost(VecOpcode, UserScalarTy,
+                                                   ScalarTy, CCH, CostKind);
         }
       }
     }
@@ -8981,7 +8980,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
   SmallVector<std::pair<Value *, const TreeEntry *>> FirstUsers;
   SmallVector<APInt> DemandedElts;
   SmallDenseSet<Value *, 4> UsedInserts;
-  DenseSet<Value *> VectorCasts;
+  DenseSet<std::pair<const TreeEntry *, Type *>> VectorCasts;
   for (ExternalUser &EU : ExternalUses) {
     // We only add extract cost once for the same scalar.
     if (!isa_and_nonnull<InsertElementInst>(EU.User) &&
@@ -9051,11 +9050,14 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
           DemandedElts.push_back(APInt::getZero(FTy->getNumElements()));
           VecId = FirstUsers.size() - 1;
           auto It = MinBWs.find(ScalarTE);
-          if (It != MinBWs.end() && VectorCasts.insert(EU.Scalar).second) {
+          if (It != MinBWs.end() &&
+              VectorCasts
+                  .insert(std::make_pair(ScalarTE, FTy->getElementType()))
+                  .second) {
             unsigned BWSz = It->second.second;
-            unsigned SrcBWSz = DL->getTypeSizeInBits(FTy->getElementType());
+            unsigned DstBWSz = DL->getTypeSizeInBits(FTy->getElementType());
             unsigned VecOpcode;
-            if (BWSz < SrcBWSz)
+            if (DstBWSz < BWSz)
               VecOpcode = Instruction::Trunc;
             else
               VecOpcode =
@@ -9108,17 +9110,20 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
   }
   // Add reduced value cost, if resized.
   if (!VectorizedVals.empty()) {
-    auto BWIt = MinBWs.find(VectorizableTree.front().get());
+    const TreeEntry &Root = *VectorizableTree.front().get();
+    auto BWIt = MinBWs.find(&Root);
     if (BWIt != MinBWs.end()) {
-      Type *DstTy = VectorizableTree.front()->Scalars.front()->getType();
+      Type *DstTy = Root.Scalars.front()->getType();
       unsigned OriginalSz = DL->getTypeSizeInBits(DstTy);
-      unsigned Opcode = Instruction::Trunc;
-      if (OriginalSz < BWIt->second.first)
-        Opcode = BWIt->second.second ? Instruction::SExt : Instruction::ZExt;
-      Type *SrcTy = IntegerType::get(DstTy->getContext(), BWIt->second.first);
-      Cost += TTI->getCastInstrCost(Opcode, DstTy, SrcTy,
-                                    TTI::CastContextHint::None,
-                                    TTI::TCK_RecipThroughput);
+      if (OriginalSz != BWIt->second.first) {
+        unsigned Opcode = Instruction::Trunc;
+        if (OriginalSz < BWIt->second.first)
+          Opcode = BWIt->second.second ? Instruction::SExt : Instruction::ZExt;
+        Type *SrcTy = IntegerType::get(DstTy->getContext(), BWIt->second.first);
+        Cost += TTI->getCastInstrCost(Opcode, DstTy, SrcTy,
+                                      TTI::CastContextHint::None,
+                                      TTI::TCK_RecipThroughput);
+      }
     }
   }

@@ -11419,9 +11424,10 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
           VecOpcode = Instruction::BitCast;
         } else if (BWSz < SrcBWSz) {
           VecOpcode = Instruction::Trunc;
-        } else if (It != MinBWs.end()) {
+        } else if (SrcIt != MinBWs.end()) {
           assert(BWSz > SrcBWSz && "Invalid cast!");
-          VecOpcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
+          VecOpcode =
+              SrcIt->second.second ? Instruction::SExt : Instruction::ZExt;
         }
       }
       Value *V = (VecOpcode != ShuffleOrOp && VecOpcode == Instruction::BitCast)
@@ -11929,7 +11935,7 @@ Value *BoUpSLP::vectorizeTree(
   // basic block. Only one extractelement per block should be emitted.
   DenseMap<Value *, DenseMap<BasicBlock *, Instruction *>> ScalarToEEs;
   SmallDenseSet<Value *, 4> UsedInserts;
-  DenseMap<Value *, Value *> VectorCasts;
+  DenseMap<std::pair<Value *, Type *>, Value *> VectorCasts;
   SmallDenseSet<Value *, 4> ScalarsWithNullptrUser;
   // Extract all of the elements with the external uses.
   for (const auto &ExternalUse : ExternalUses) {
@@ -12050,7 +12056,9 @@ Value *BoUpSLP::vectorizeTree(
         // Need to use original vector, if the root is truncated.
         auto BWIt = MinBWs.find(E);
         if (BWIt != MinBWs.end() && Vec->getType() != VU->getType()) {
-          auto VecIt = VectorCasts.find(Scalar);
+          auto *ScalarTy = FTy->getElementType();
+          auto Key = std::make_pair(Vec, ScalarTy);
+          auto VecIt = VectorCasts.find(Key);
           if (VecIt == VectorCasts.end()) {
             IRBuilder<>::InsertPointGuard Guard(Builder);
             if (auto *IVec = dyn_cast<Instruction>(Vec))
@@ -12058,10 +12066,10 @@ Value *BoUpSLP::vectorizeTree(
             Vec = Builder.CreateIntCast(
                 Vec,
                 FixedVectorType::get(
-                    cast<FixedVectorType>(VU->getType())->getElementType(),
+                    ScalarTy,
                     cast<FixedVectorType>(Vec->getType())->getNumElements()),
                 BWIt->second.second);
-            VectorCasts.try_emplace(Scalar, Vec);
+            VectorCasts.try_emplace(Key, Vec);
           } else {
             Vec = VecIt->second;
           }
diff --git a/llvm/test/Transforms/SLPVectorizer/SystemZ/minbitwidth-trunc.ll b/llvm/test/Transforms/SLPVectorizer/SystemZ/minbitwidth-trunc.ll
new file mode 100644
index 0000000000000..e1942eb326079
--- /dev/null
+++ b/llvm/test/Transforms/SLPVectorizer/SystemZ/minbitwidth-trunc.ll
@@ -0,0 +1,48 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S --passes=slp-vectorizer -mtriple=s390x-unknown-linux -mcpu=z14 < %s | FileCheck %s
+
+define void @test() {
+; CHECK-LABEL: define void @test(
+; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i8 0 to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i8 0 to i32
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x i32> <i32 0, i32 poison, i32 0, i32 0>, i32 [[TMP2]], i32 1
+; CHECK-NEXT:    [[TMP4:%.*]] = select <4 x i1> zeroinitializer, <4 x i32> zeroinitializer, <4 x i32> [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = select i1 false, i32 0, i32 0
+; CHECK-NEXT:    [[TMP6:%.*]] = select i1 false, i32 0, i32 [[TMP1]]
+; CHECK-NEXT:    [[TMP7:%.*]] = select i1 false, i32 0, i32 [[TMP2]]
+; CHECK-NEXT:    [[TMP8:%.*]] = call i32 @llvm.vector.reduce.xor.v4i32(<4 x i32> [[TMP4]])
+; CHECK-NEXT:    [[OP_RDX:%.*]] = xor i32 [[TMP8]], [[TMP5]]
+; CHECK-NEXT:    [[OP_RDX1:%.*]] = xor i32 [[TMP6]], [[TMP7]]
+; CHECK-NEXT:    [[OP_RDX2:%.*]] = xor i32 [[OP_RDX]], [[OP_RDX1]]
+; CHECK-NEXT:    [[TMP9:%.*]] = trunc i32 [[OP_RDX2]] to i16
+; CHECK-NEXT:    store i16 [[TMP9]], ptr null, align 2
+; CHECK-NEXT:    ret void
+;
+  %1 = zext i8 0 to i32
+  %.not = icmp sgt i32 0, %1
+  %2 = zext i8 0 to i32
+  %3 = select i1 %.not, i32 0, i32 0
+  %4 = zext i8 0 to i32
+  %.not.1 = icmp sgt i32 0, %4
+  %5 = zext i8 0 to i32
+  %6 = select i1 %.not.1, i32 0, i32 %5
+  %7 = xor i32 %6, %3
+  %8 = zext i8 0 to i32
+  %.not.2 = icmp sgt i32 0, %8
+  %9 = select i1 %.not.2, i32 0, i32 0
+  %10 = xor i32 %9, %7
+  %11 = zext i8 0 to i32
+  %.not.3 = icmp sgt i32 0, %11
+  %12 = select i1 %.not.3, i32 0, i32 0
+  %13 = xor i32 %12, %10
+  %14 = select i1 false, i32 0, i32 0
+  %15 = xor i32 %14, %13
+  %16 = select i1 false, i32 0, i32 %2
+  %17 = xor i32 %16, %15
+  %18 = select i1 false, i32 0, i32 %5
+  %19 = xor i32 %18, %17
+  %20 = trunc i32 %19 to i16
+  store i16 %20, ptr null, align 2
+  ret void
+}
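
A note on the getEntryCost hunks above, for readers unfamiliar with the TTI API: TargetTransformInfo::getCastInstrCost takes the destination type before the source type, so the old call sites were passing the min-bitwidth cast operands in swapped order. Below is a minimal standalone sketch of the convention; the helper name costIntResize and its parameters are made up for illustration and are not part of the patch:

#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instruction.h"

using namespace llvm;

// Hypothetical helper (not from the patch): cost of resizing a vector of
// NumElts integers from SrcBits to DstBits; IsSigned picks sext over zext
// when widening.
static InstructionCost costIntResize(const TargetTransformInfo &TTIRef,
                                     LLVMContext &Ctx, unsigned NumElts,
                                     unsigned SrcBits, unsigned DstBits,
                                     bool IsSigned) {
  auto *SrcTy = FixedVectorType::get(IntegerType::get(Ctx, SrcBits), NumElts);
  auto *DstTy = FixedVectorType::get(IntegerType::get(Ctx, DstBits), NumElts);
  unsigned Opcode = DstBits < SrcBits
                        ? Instruction::Trunc
                        : (IsSigned ? Instruction::SExt : Instruction::ZExt);
  // getCastInstrCost(Opcode, Dst, Src, ...): the destination type comes first.
  return TTIRef.getCastInstrCost(Opcode, DstTy, SrcTy,
                                 TargetTransformInfo::CastContextHint::None,
                                 TargetTransformInfo::TCK_RecipThroughput);
}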
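
Similarly, the VectorCasts changes in vectorizeTree re-key the cast cache from the scalar alone to a (vector, element type) pair, so every extract resized to the same element type reuses one emitted cast while a different element type still gets its own. A rough sketch of that keying scheme, again with hypothetical names (getOrCreateResizedVec is not from the patch):

#include "llvm/ADT/DenseMap.h"
#include "llvm/IR/IRBuilder.h"

using namespace llvm;

// Illustrative cache: one integer cast per (source vector, element type).
using CastCache = DenseMap<std::pair<Value *, Type *>, Value *>;

static Value *getOrCreateResizedVec(IRBuilder<> &Builder, CastCache &Casts,
                                    Value *Vec, Type *EltTy, bool IsSigned) {
  auto Key = std::make_pair(Vec, EltTy);
  if (Value *Cached = Casts.lookup(Key))
    return Cached; // Reuse the cast already emitted for this exact pair.
  auto *SrcVecTy = cast<FixedVectorType>(Vec->getType());
  auto *DstVecTy = FixedVectorType::get(EltTy, SrcVecTy->getNumElements());
  Value *Resized = Builder.CreateIntCast(Vec, DstVecTy, IsSigned);
  Casts.try_emplace(Key, Resized);
  return Resized;
}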