diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index cc008d3337ad4..bfe1fb7817df3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -685,12 +685,14 @@ static Instruction *shrinkInsertElt(CastInst &Trunc,
   return nullptr;
 }
 
-Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
-  if (Instruction *Result = commonCastTransforms(CI))
+Instruction *InstCombiner::visitTrunc(TruncInst &Trunc) {
+  if (Instruction *Result = commonCastTransforms(Trunc))
     return Result;
 
-  Value *Src = CI.getOperand(0);
-  Type *DestTy = CI.getType(), *SrcTy = Src->getType();
+  Value *Src = Trunc.getOperand(0);
+  Type *DestTy = Trunc.getType(), *SrcTy = Src->getType();
+  unsigned DestWidth = DestTy->getScalarSizeInBits();
+  unsigned SrcWidth = SrcTy->getScalarSizeInBits();
   ConstantInt *Cst;
 
   // Attempt to truncate the entire input expression tree to the destination
@@ -698,17 +700,17 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
   // expression tree to something weird like i93 unless the source is also
   // strange.
   if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) &&
-      canEvaluateTruncated(Src, DestTy, *this, &CI)) {
+      canEvaluateTruncated(Src, DestTy, *this, &Trunc)) {
 
     // If this cast is a truncate, evaluting in a different type always
     // eliminates the cast, so it is always a win.
     LLVM_DEBUG(
         dbgs() << "ICE: EvaluateInDifferentType converting expression type"
                   " to avoid cast: "
-               << CI << '\n');
+               << Trunc << '\n');
     Value *Res = EvaluateInDifferentType(Src, DestTy, false);
     assert(Res->getType() == DestTy);
-    return replaceInstUsesWith(CI, Res);
+    return replaceInstUsesWith(Trunc, Res);
   }
 
   // Test if the trunc is the user of a select which is part of a
@@ -716,17 +718,17 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
   // Even simplifying demanded bits can break the canonical form of a
   // min/max.
   Value *LHS, *RHS;
-  if (SelectInst *SI = dyn_cast<SelectInst>(CI.getOperand(0)))
-    if (matchSelectPattern(SI, LHS, RHS).Flavor != SPF_UNKNOWN)
+  if (SelectInst *Sel = dyn_cast<SelectInst>(Src))
+    if (matchSelectPattern(Sel, LHS, RHS).Flavor != SPF_UNKNOWN)
       return nullptr;
 
   // See if we can simplify any instructions used by the input whose sole
   // purpose is to compute bits we don't care about.
-  if (SimplifyDemandedInstructionBits(CI))
-    return &CI;
+  if (SimplifyDemandedInstructionBits(Trunc))
+    return &Trunc;
 
-  if (DestTy->getScalarSizeInBits() == 1) {
-    Value *Zero = Constant::getNullValue(Src->getType());
+  if (DestWidth == 1) {
+    Value *Zero = Constant::getNullValue(SrcTy);
     if (DestTy->isIntegerTy()) {
       // Canonicalize trunc x to i1 -> icmp ne (and x, 1), 0 (scalar only).
       // TODO: We canonicalize to more instructions here because we are probably
@@ -743,14 +745,14 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
     const APInt *C;
     if (match(Src, m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) {
       // trunc (lshr X, C) to i1 --> icmp ne (and X, C'), 0
-      APInt MaskC = APInt(SrcTy->getScalarSizeInBits(), 1).shl(*C);
+      APInt MaskC = APInt(SrcWidth, 1).shl(*C);
       Value *And = Builder.CreateAnd(X, ConstantInt::get(SrcTy, MaskC));
       return new ICmpInst(ICmpInst::ICMP_NE, And, Zero);
     }
     if (match(Src, m_OneUse(m_c_Or(m_LShr(m_Value(X), m_APInt(C)),
                                    m_Deferred(X))))) {
       // trunc (or (lshr X, C), X) to i1 --> icmp ne (and X, C'), 0
-      APInt MaskC = APInt(SrcTy->getScalarSizeInBits(), 1).shl(*C) | 1;
+      APInt MaskC = APInt(SrcWidth, 1).shl(*C) | 1;
       Value *And = Builder.CreateAnd(X, ConstantInt::get(SrcTy, MaskC));
       return new ICmpInst(ICmpInst::ICMP_NE, And, Zero);
     }
@@ -772,7 +774,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
     // If the shift amount is larger than the size of A, then the result is
     // known to be zero because all the input bits got shifted out.
     if (Cst->getZExtValue() >= ASize)
-      return replaceInstUsesWith(CI, Constant::getNullValue(DestTy));
+      return replaceInstUsesWith(Trunc, Constant::getNullValue(DestTy));
 
     // Since we're doing an lshr and a zero extend, and know that the shift
     // amount is smaller than ASize, it is always safe to do the shift in A's
@@ -791,10 +793,8 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
   if (Src->hasOneUse() &&
       match(Src, m_LShr(m_SExt(m_Value(A)), m_ConstantInt(Cst)))) {
     Value *SExt = cast<Instruction>(Src)->getOperand(0);
-    const unsigned SExtSize = SExt->getType()->getPrimitiveSizeInBits();
-    const unsigned ASize = A->getType()->getPrimitiveSizeInBits();
-    const unsigned CISize = CI.getType()->getPrimitiveSizeInBits();
-    const unsigned MaxAmt = SExtSize - std::max(CISize, ASize);
+    unsigned ASize = A->getType()->getPrimitiveSizeInBits();
+    unsigned MaxAmt = SrcWidth - std::max(DestWidth, ASize);
     unsigned ShiftAmt = Cst->getZExtValue();
 
     // This optimization can be only performed when zero bits generated by
@@ -803,24 +803,24 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
     // FIXME: Instead of bailing when the shift is too large, use and to clear
     // the extra bits.
     if (ShiftAmt <= MaxAmt) {
-      if (CISize == ASize)
-        return BinaryOperator::CreateAShr(A, ConstantInt::get(CI.getType(),
-                                          std::min(ShiftAmt, ASize - 1)));
+      if (DestWidth == ASize)
+        return BinaryOperator::CreateAShr(
+            A, ConstantInt::get(DestTy, std::min(ShiftAmt, ASize - 1)));
       if (SExt->hasOneUse()) {
         Value *Shift = Builder.CreateAShr(A, std::min(ShiftAmt, ASize - 1));
         Shift->takeName(Src);
-        return CastInst::CreateIntegerCast(Shift, CI.getType(), true);
+        return CastInst::CreateIntegerCast(Shift, DestTy, true);
       }
     }
   }
 
-  if (Instruction *I = narrowBinOp(CI))
+  if (Instruction *I = narrowBinOp(Trunc))
     return I;
 
-  if (Instruction *I = shrinkSplatShuffle(CI, Builder))
+  if (Instruction *I = shrinkSplatShuffle(Trunc, Builder))
     return I;
 
-  if (Instruction *I = shrinkInsertElt(CI, Builder))
+  if (Instruction *I = shrinkInsertElt(Trunc, Builder))
     return I;
 
   if (Src->hasOneUse() && isa<IntegerType>(SrcTy) &&
@@ -831,18 +831,17 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
         !match(A, m_Shr(m_Value(), m_Constant()))) {
       // Skip shifts of shift by constants. It undoes a combine in
       // FoldShiftByConstant and is the extend in reg pattern.
-      const unsigned DestSize = DestTy->getScalarSizeInBits();
-      if (Cst->getValue().ult(DestSize)) {
+      if (Cst->getValue().ult(DestWidth)) {
         Value *NewTrunc = Builder.CreateTrunc(A, DestTy, A->getName() + ".tr");
 
         return BinaryOperator::Create(
           Instruction::Shl, NewTrunc,
-          ConstantInt::get(DestTy, Cst->getValue().trunc(DestSize)));
+          ConstantInt::get(DestTy, Cst->getValue().trunc(DestWidth)));
       }
     }
   }
 
-  if (Instruction *I = foldVecTruncToExtElt(CI, *this))
+  if (Instruction *I = foldVecTruncToExtElt(Trunc, *this))
     return I;
 
   // Whenever an element is extracted from a vector, and then truncated,
@@ -856,13 +855,11 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
   Value *VecOp;
   if (match(Src, m_OneUse(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst))))) {
     auto *VecOpTy = cast<VectorType>(VecOp->getType());
-    unsigned DestScalarSize = DestTy->getScalarSizeInBits();
-    unsigned VecOpScalarSize = VecOpTy->getScalarSizeInBits();
     unsigned VecNumElts = VecOpTy->getNumElements();
 
     // A badly fit destination size would result in an invalid cast.
-    if (VecOpScalarSize % DestScalarSize == 0) {
-      uint64_t TruncRatio = VecOpScalarSize / DestScalarSize;
+    if (SrcWidth % DestWidth == 0) {
+      uint64_t TruncRatio = SrcWidth / DestWidth;
       uint64_t BitCastNumElts = VecNumElts * TruncRatio;
       uint64_t VecOpIdx = Cst->getZExtValue();
       uint64_t NewIdx = DL.isBigEndian() ? (VecOpIdx + 1) * TruncRatio - 1
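
---

Reviewer notes (appended commentary; not part of the patch):

The two folds in the @@ -743 hunk rewrite a truncation to i1 as a mask-and-compare:
"trunc (lshr X, C) to i1" keeps bit C of X, and "trunc (or (lshr X, C), X) to i1"
is true iff bit C or bit 0 of X is set. A minimal standalone sanity check in plain
C++, exhaustively covering an assumed i8 source type (the program below is
illustrative and independent of InstCombine):

    // Exhaustive check of the two "trunc to i1" folds for an i8 source.
    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned C = 0; C < 8; ++C) {
        for (unsigned XI = 0; XI < 256; ++XI) {
          uint8_t X = static_cast<uint8_t>(XI);

          // trunc (lshr X, C) to i1  -->  icmp ne (and X, 1 << C), 0
          bool TruncLShr = ((X >> C) & 1) != 0;
          bool MaskForm = (X & (1u << C)) != 0;
          assert(TruncLShr == MaskForm);

          // trunc (or (lshr X, C), X) to i1  -->  icmp ne (and X, (1 << C) | 1), 0
          bool TruncOr = (((X >> C) | X) & 1) != 0;
          bool MaskOrForm = (X & ((1u << C) | 1u)) != 0;
          assert(TruncOr == MaskOrForm);
        }
      }
      return 0;
    }

Compile with e.g. "clang++ -std=c++17 check.cc && ./a.out"; a silent exit means
both identities held for every value and shift amount.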
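The @@ -791/-803 hunks fold "trunc (lshr (sext A), C)" into an ashr on A when the
lshr only consumes sign-extension bits (ShiftAmt <= SrcWidth - max(DestWidth,
ASize)). A standalone check for the DestWidth == ASize case, assuming the shape
i8 -> sext to i32 -> lshr -> trunc to i8 (widths chosen for illustration):

    // Exhaustive check: trunc (lshr (sext A), C) to i8 == ashr A, min(C, 7),
    // for C <= MaxAmt = 32 - max(8, 8) = 24.
    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned C = 0; C <= 24; ++C) {
        for (int AI = -128; AI < 128; ++AI) {
          int8_t A = static_cast<int8_t>(AI);
          int32_t SExt = A;                                 // sext i8 %A to i32
          uint32_t LShr = static_cast<uint32_t>(SExt) >> C; // lshr i32 %SExt, C
          int8_t Wide = static_cast<int8_t>(LShr);          // trunc i32 %LShr to i8
          unsigned Sh = C < 7 ? C : 7;                      // min(ShiftAmt, ASize - 1)
          // Relies on '>>' of a negative int being an arithmetic shift
          // (true on mainstream compilers; guaranteed since C++20).
          int8_t Narrow = static_cast<int8_t>(A >> Sh);     // ashr i8 %A, Sh
          assert(Wide == Narrow);
        }
      }
      return 0;
    }

For C > 24, the lshr pulls in zeros above the sign-extended bits, so the low byte
is no longer a pure arithmetic shift of A; that is exactly what the MaxAmt guard
rules out.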
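Similarly for the @@ -831 hunk: "trunc (shl A, C) to iN" equals
"shl (trunc A to iN), C" because the low N bits of a left shift depend only on
the low N bits of A. The ult(DestWidth) guard keeps the narrowed shift amount
valid, since a shl by an amount >= the bit width would be poison in the narrow
type. A brute-force check over an assumed i16 -> i8 truncation:

    // Exhaustive check: trunc (shl X, C) to i8 == shl (trunc X to i8), C
    // for C < DestWidth (= 8 here).
    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned C = 0; C < 8; ++C) {
        for (unsigned XI = 0; XI < 65536; ++XI) {
          uint16_t X = static_cast<uint16_t>(XI);
          // trunc (shl X, C) to i8
          uint8_t Wide = static_cast<uint8_t>(static_cast<uint16_t>(X << C));
          // shl (trunc X to i8), C
          uint8_t Narrow = static_cast<uint8_t>(static_cast<uint8_t>(X) << C);
          assert(Wide == Narrow);
        }
      }
      return 0;
    }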