diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h index 640e1fda2119..7a69d20154b0 100644 --- a/llvm/include/llvm/Analysis/InstructionSimplify.h +++ b/llvm/include/llvm/Analysis/InstructionSimplify.h @@ -123,6 +123,14 @@ struct SimplifyQuery { Copy.CanUseUndef = false; return Copy; } + + /// If CanUseUndef is true, returns whether \p V is undef. + /// Otherwise always return false. + bool isUndefValue(Value *V) const { + if (!CanUseUndef) + return false; + return isa(V); + } }; // NOTE: the explicit multiple argument versions of these functions are diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 893de596e1a7..a5b9cbf3f03c 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -415,9 +415,9 @@ static Value *ThreadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS, return TV; // If one branch simplified to undef, return the other one. - if (TV && Q.CanUseUndef && isa(TV)) + if (TV && Q.isUndefValue(TV)) return FV; - if (FV && Q.CanUseUndef && isa(FV)) + if (FV && Q.isUndefValue(FV)) return TV; // If applying the operation did not change the true and false select values, @@ -612,7 +612,7 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, return C; // X + undef -> undef - if (Q.CanUseUndef && match(Op1, m_Undef())) + if (Q.isUndefValue(Op1)) return Op1; // X + 0 -> X @@ -732,7 +732,7 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, // X - undef -> undef // undef - X -> undef - if (match(Op0, m_Undef()) || match(Op1, m_Undef())) + if (Q.isUndefValue(Op0) || Q.isUndefValue(Op1)) return UndefValue::get(Op0->getType()); // X - 0 -> X @@ -867,7 +867,7 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, // X * undef -> 0 // X * 0 -> 0 - if (Q.CanUseUndef && match(Op1, m_CombineOr(m_Undef(), m_Zero()))) + if (Q.isUndefValue(Op1) || match(Op1, m_Zero())) return Constant::getNullValue(Op0->getType()); // X * 1 -> X @@ -920,12 +920,13 @@ Value *llvm::SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { /// Check for common or similar folds of integer division or integer remainder. /// This applies to all 4 opcodes (sdiv/udiv/srem/urem). -static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv) { +static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv, + const SimplifyQuery &Q) { Type *Ty = Op0->getType(); // X / undef -> undef // X % undef -> undef - if (match(Op1, m_Undef())) + if (Q.isUndefValue(Op1)) return Op1; // X / 0 -> undef @@ -942,14 +943,14 @@ static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv) { unsigned NumElts = VTy->getNumElements(); for (unsigned i = 0; i != NumElts; ++i) { Constant *Elt = Op1C->getAggregateElement(i); - if (Elt && (Elt->isNullValue() || isa(Elt))) + if (Elt && (Elt->isNullValue() || Q.isUndefValue(Elt))) return UndefValue::get(Ty); } } // undef / X -> 0 // undef % X -> 0 - if (match(Op0, m_Undef())) + if (Q.isUndefValue(Op0)) return Constant::getNullValue(Ty); // 0 / X -> 0 @@ -1043,7 +1044,7 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) return C; - if (Value *V = simplifyDivRem(Op0, Op1, true)) + if (Value *V = simplifyDivRem(Op0, Op1, true, Q)) return V; bool IsSigned = Opcode == Instruction::SDiv; @@ -1101,7 +1102,7 @@ static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) return C; - if (Value *V = simplifyDivRem(Op0, Op1, false)) + if (Value *V = simplifyDivRem(Op0, Op1, false, Q)) return V; // (X % Y) % Y -> X % Y @@ -1197,13 +1198,13 @@ Value *llvm::SimplifyURemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { } /// Returns true if a shift by \c Amount always yields undef. -static bool isUndefShift(Value *Amount) { +static bool isUndefShift(Value *Amount, const SimplifyQuery &Q) { Constant *C = dyn_cast(Amount); if (!C) return false; // X shift by undef -> undef because it may shift by the bitwidth. - if (isa(C)) + if (Q.isUndefValue(C)) return true; // Shifting by the bitwidth or more is undefined. @@ -1217,7 +1218,7 @@ static bool isUndefShift(Value *Amount) { for (unsigned I = 0, E = cast(C->getType())->getNumElements(); I != E; ++I) - if (!isUndefShift(C->getAggregateElement(I))) + if (!isUndefShift(C->getAggregateElement(I), Q)) return false; return true; } @@ -1245,7 +1246,7 @@ static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0, return Op0; // Fold undefined shifts. - if (isUndefShift(Op1)) + if (isUndefShift(Op1, Q)) return UndefValue::get(Op0->getType()); // If the operation is with the result of a select instruction, check whether @@ -1289,7 +1290,7 @@ static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0, // undef >> X -> 0 // undef >> X -> undef (if it's exact) - if (Q.CanUseUndef && match(Op0, m_Undef())) + if (Q.isUndefValue(Op0)) return isExact ? Op0 : Constant::getNullValue(Op0->getType()); // The low bit cannot be shifted out of an exact shift if it is set. @@ -1311,7 +1312,7 @@ static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, // undef << X -> 0 // undef << X -> undef if (if it's NSW/NUW) - if (Q.CanUseUndef && match(Op0, m_Undef())) + if (Q.isUndefValue(Op0)) return isNSW || isNUW ? Op0 : Constant::getNullValue(Op0->getType()); // (X >> A) << A -> X @@ -2004,7 +2005,7 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, return C; // X & undef -> 0 - if (Q.CanUseUndef && match(Op1, m_Undef())) + if (Q.isUndefValue(Op1)) return Constant::getNullValue(Op0->getType()); // X & X = X @@ -2162,7 +2163,7 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, // X | undef -> -1 // X | -1 = -1 // Do not return Op1 because it may contain undef elements if it's a vector. - if ((Q.CanUseUndef && match(Op1, m_Undef())) || match(Op1, m_AllOnes())) + if (Q.isUndefValue(Op1) || match(Op1, m_AllOnes())) return Constant::getAllOnesValue(Op0->getType()); // X | X = X @@ -2304,7 +2305,7 @@ static Value *SimplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, return C; // A ^ undef -> undef - if (Q.CanUseUndef && match(Op1, m_Undef())) + if (Q.isUndefValue(Op1)) return Op1; // A ^ 0 = A @@ -3270,12 +3271,12 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // For EQ and NE, we can always pick a value for the undef to make the // predicate pass or fail, so we can return undef. // Matches behavior in llvm::ConstantFoldCompareInstruction. - if (Q.CanUseUndef && isa(RHS) && ICmpInst::isEquality(Pred)) + if (Q.isUndefValue(RHS) && ICmpInst::isEquality(Pred)) return UndefValue::get(ITy); // icmp X, X -> true/false // icmp X, undef -> true/false because undef could be X. - if (LHS == RHS || (Q.CanUseUndef && isa(RHS))) + if (LHS == RHS || Q.isUndefValue(RHS)) return ConstantInt::get(ITy, CmpInst::isTrueWhenEqual(Pred)); if (Value *V = simplifyICmpOfBools(Pred, LHS, RHS, Q)) @@ -3602,7 +3603,7 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, // fcmp pred x, undef and fcmp pred undef, x // fold to true if unordered, false if ordered - if (Q.CanUseUndef && (isa(LHS) || isa(RHS))) { + if (Q.isUndefValue(LHS) || Q.isUndefValue(RHS)) { // Choosing NaN for the undef will always make unordered comparison succeed // and ordered comparison fail. return ConstantInt::get(RetTy, CmpInst::isUnordered(Pred)); @@ -4050,7 +4051,7 @@ static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, return ConstantFoldSelectInstruction(CondC, TrueC, FalseC); // select undef, X, Y -> X or Y - if (Q.CanUseUndef && isa(CondC)) + if (Q.isUndefValue(CondC)) return isa(FalseVal) ? FalseVal : TrueVal; // TODO: Vector constants with undef elements don't simplify. @@ -4076,9 +4077,9 @@ static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, if (TrueVal == FalseVal) return TrueVal; - if (Q.CanUseUndef && isa(TrueVal)) // select ?, undef, X -> X + if (Q.isUndefValue(TrueVal)) // select ?, undef, X -> X return FalseVal; - if (Q.CanUseUndef && isa(FalseVal)) // select ?, X, undef -> X + if (Q.isUndefValue(FalseVal)) // select ?, X, undef -> X return TrueVal; // Deal with partial undef vector constants: select ?, VecC, VecC' --> VecC'' @@ -4099,9 +4100,9 @@ static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, // one element is undef, choose the defined element as the safe result. if (TEltC == FEltC) NewC.push_back(TEltC); - else if (Q.CanUseUndef && isa(TEltC)) + else if (Q.isUndefValue(TEltC)) NewC.push_back(FEltC); - else if (Q.CanUseUndef && isa(FEltC)) + else if (Q.isUndefValue(FEltC)) NewC.push_back(TEltC); else break; @@ -4152,7 +4153,7 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef Ops, else if (VectorType *VT = dyn_cast(Ops[1]->getType())) GEPTy = VectorType::get(GEPTy, VT->getElementCount()); - if (Q.CanUseUndef && isa(Ops[0])) + if (Q.isUndefValue(Ops[0])) return UndefValue::get(GEPTy); bool IsScalableVec = isa(SrcTy); @@ -4261,7 +4262,7 @@ static Value *SimplifyInsertValueInst(Value *Agg, Value *Val, return ConstantFoldInsertValueInstruction(CAgg, CVal, Idxs); // insertvalue x, undef, n -> x - if (Q.CanUseUndef && match(Val, m_Undef())) + if (Q.isUndefValue(Val)) return Agg; // insertvalue x, (extractvalue y, n), n @@ -4269,7 +4270,7 @@ static Value *SimplifyInsertValueInst(Value *Agg, Value *Val, if (EV->getAggregateOperand()->getType() == Agg->getType() && EV->getIndices() == Idxs) { // insertvalue undef, (extractvalue y, n), n -> y - if (Q.CanUseUndef && match(Agg, m_Undef())) + if (Q.isUndefValue(Agg)) return EV->getAggregateOperand(); // insertvalue y, (extractvalue y, n), n -> y @@ -4303,12 +4304,12 @@ Value *llvm::SimplifyInsertElementInst(Value *Vec, Value *Val, Value *Idx, } // If index is undef, it might be out of bounds (see above case) - if (Q.CanUseUndef && isa(Idx)) + if (Q.isUndefValue(Idx)) return UndefValue::get(Vec->getType()); // If the scalar is undef, and there is no risk of propagating poison from the // vector value, simplify to the vector value. - if (Q.CanUseUndef && isa(Val) && + if (Q.isUndefValue(Val) && isGuaranteedNotToBeUndefOrPoison(Vec)) return Vec; @@ -4364,7 +4365,7 @@ static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, if (auto *Splat = CVec->getSplatValue()) return Splat; - if (Q.CanUseUndef && isa(Vec)) + if (Q.isUndefValue(Vec)) return UndefValue::get(VecVTy->getElementType()); } @@ -4381,7 +4382,7 @@ static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, // An undef extract index can be arbitrarily chosen to be an out-of-range // index value, which would result in the instruction being undef. - if (Q.CanUseUndef && isa(Idx)) + if (Q.isUndefValue(Idx)) return UndefValue::get(VecVTy->getElementType()); return nullptr; @@ -4401,7 +4402,7 @@ static Value *SimplifyPHINode(PHINode *PN, const SimplifyQuery &Q) { for (Value *Incoming : PN->incoming_values()) { // If the incoming value is the phi node itself, it can safely be skipped. if (Incoming == PN) continue; - if (Q.CanUseUndef && isa(Incoming)) { + if (Q.isUndefValue(Incoming)) { // Remember that we saw an undef value, but otherwise ignore them. HasUndefInput = true; continue; @@ -4596,7 +4597,7 @@ static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1, // A shuffle of a splat is always the splat itself. Legal if the shuffle's // value type is same as the input vectors' type. if (auto *OpShuf = dyn_cast(Op0)) - if (Q.CanUseUndef && isa(Op1) && RetTy == InVecTy && + if (Q.isUndefValue(Op1) && RetTy == InVecTy && is_splat(OpShuf->getShuffleMask())) return Op0; @@ -4677,11 +4678,12 @@ static Constant *propagateNaN(Constant *In) { /// transforms based on undef/NaN because the operation itself makes no /// difference to the result. static Constant *simplifyFPOp(ArrayRef Ops, - FastMathFlags FMF = FastMathFlags()) { + FastMathFlags FMF, + const SimplifyQuery &Q) { for (Value *V : Ops) { bool IsNan = match(V, m_NaN()); bool IsInf = match(V, m_Inf()); - bool IsUndef = match(V, m_Undef()); + bool IsUndef = Q.isUndefValue(V); // If this operation has 'nnan' or 'ninf' and at least 1 disallowed operand // (an undef operand can be chosen to be Nan/Inf), then the result of @@ -4704,7 +4706,7 @@ static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q)) return C; - if (Constant *C = simplifyFPOp({Op0, Op1}, FMF)) + if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q)) return C; // fadd X, -0 ==> X @@ -4751,7 +4753,7 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q)) return C; - if (Constant *C = simplifyFPOp({Op0, Op1}, FMF)) + if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q)) return C; // fsub X, +0 ==> X @@ -4793,7 +4795,7 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, static Value *SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q, unsigned MaxRecurse) { - if (Constant *C = simplifyFPOp({Op0, Op1}, FMF)) + if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q)) return C; // fmul X, 1.0 ==> X @@ -4860,7 +4862,7 @@ static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q)) return C; - if (Constant *C = simplifyFPOp({Op0, Op1}, FMF)) + if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q)) return C; // X / 1.0 -> X @@ -4905,7 +4907,7 @@ static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, if (Constant *C = foldOrCommuteConstant(Instruction::FRem, Op0, Op1, Q)) return C; - if (Constant *C = simplifyFPOp({Op0, Op1}, FMF)) + if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q)) return C; // Unlike fdiv, the result of frem always matches the sign of the dividend. @@ -5277,7 +5279,7 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, std::swap(Op0, Op1); // Assume undef is the limit value. - if (isa(Op1)) + if (Q.isUndefValue(Op1)) return ConstantInt::get(ReturnType, getMaxMinLimit(IID, BitWidth)); const APInt *C; @@ -5328,7 +5330,7 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, // undef - X -> { undef, false } // X + undef -> { undef, false } // undef + x -> { undef, false } - if (isa(Op0) || isa(Op1)) { + if (Q.isUndefValue(Op0) || Q.isUndefValue(Op1)) { return ConstantStruct::get( cast(ReturnType), {UndefValue::get(ReturnType->getStructElementType(0)), @@ -5343,7 +5345,7 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, return Constant::getNullValue(ReturnType); // undef * X -> { 0, false } // X * undef -> { 0, false } - if (match(Op0, m_Undef()) || match(Op1, m_Undef())) + if (Q.isUndefValue(Op0) || Q.isUndefValue(Op1)) return Constant::getNullValue(ReturnType); break; case Intrinsic::uadd_sat: @@ -5357,7 +5359,7 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, // sat(undef + X) -> -1 // For unsigned: Assume undef is MAX, thus we saturate to MAX (-1). // For signed: Assume undef is ~X, in which case X + ~X = -1. - if (match(Op0, m_Undef()) || match(Op1, m_Undef())) + if (Q.isUndefValue(Op0) || Q.isUndefValue(Op1)) return Constant::getAllOnesValue(ReturnType); // X + 0 -> X @@ -5374,7 +5376,7 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, LLVM_FALLTHROUGH; case Intrinsic::ssub_sat: // X - X -> 0, X - undef -> 0, undef - X -> 0 - if (Op0 == Op1 || match(Op0, m_Undef()) || match(Op1, m_Undef())) + if (Op0 == Op1 || Q.isUndefValue(Op0) || Q.isUndefValue(Op1)) return Constant::getNullValue(ReturnType); // X - 0 -> X if (match(Op1, m_Zero())) @@ -5413,9 +5415,9 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, if (Op0 == Op1) return Op0; // If one argument is undef, return the other argument. - if (match(Op0, m_Undef())) + if (Q.isUndefValue(Op0)) return Op1; - if (match(Op1, m_Undef())) + if (Q.isUndefValue(Op1)) return Op0; // If one argument is NaN, return other or NaN appropriately. @@ -5492,11 +5494,11 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { *ShAmtArg = Call->getArgOperand(2); // If both operands are undef, the result is undef. - if (Q.CanUseUndef && match(Op0, m_Undef()) && match(Op1, m_Undef())) + if (Q.isUndefValue(Op0) && Q.isUndefValue(Op1)) return UndefValue::get(F->getReturnType()); // If shift amount is undef, assume it is zero. - if (Q.CanUseUndef && match(ShAmtArg, m_Undef())) + if (Q.isUndefValue(ShAmtArg)) return Call->getArgOperand(IID == Intrinsic::fshl ? 0 : 1); const APInt *ShAmtC; @@ -5513,7 +5515,7 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { Value *Op0 = Call->getArgOperand(0); Value *Op1 = Call->getArgOperand(1); Value *Op2 = Call->getArgOperand(2); - if (Value *V = simplifyFPOp({ Op0, Op1, Op2 })) + if (Value *V = simplifyFPOp({ Op0, Op1, Op2 }, {}, Q)) return V; return nullptr; }