diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp index ae926f95cefe18..16b0880ce2f978 100644 --- a/llvm/lib/IR/ConstantFold.cpp +++ b/llvm/lib/IR/ConstantFold.cpp @@ -1704,7 +1704,7 @@ static ICmpInst::Predicate evaluateICmpRelation(Constant *V1, Constant *V2, return ICmpInst::BAD_ICMP_PREDICATE; } -Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, +Constant *llvm::ConstantFoldCompareInstruction(CmpInst::Predicate Predicate, Constant *C1, Constant *C2) { Type *ResultTy; if (VectorType *VT = dyn_cast(C1->getType())) @@ -1714,10 +1714,10 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, ResultTy = Type::getInt1Ty(C1->getContext()); // Fold FCMP_FALSE/FCMP_TRUE unconditionally. - if (pred == FCmpInst::FCMP_FALSE) + if (Predicate == FCmpInst::FCMP_FALSE) return Constant::getNullValue(ResultTy); - if (pred == FCmpInst::FCMP_TRUE) + if (Predicate == FCmpInst::FCMP_TRUE) return Constant::getAllOnesValue(ResultTy); // Handle some degenerate cases first @@ -1725,7 +1725,6 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, return PoisonValue::get(ResultTy); if (isa(C1) || isa(C2)) { - CmpInst::Predicate Predicate = CmpInst::Predicate(pred); bool isIntegerPredicate = ICmpInst::isIntPredicate(Predicate); // For EQ and NE, we can always pick a value for the undef to make the // predicate pass or fail, so we can return undef. @@ -1750,9 +1749,9 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, if (!isa(GV) && !GV->hasExternalWeakLinkage() && !NullPointerIsDefined(nullptr /* F */, GV->getType()->getAddressSpace())) { - if (pred == ICmpInst::ICMP_EQ) + if (Predicate == ICmpInst::ICMP_EQ) return ConstantInt::getFalse(C1->getContext()); - else if (pred == ICmpInst::ICMP_NE) + else if (Predicate == ICmpInst::ICMP_NE) return ConstantInt::getTrue(C1->getContext()); } // icmp eq/ne(GV,null) -> false/true @@ -1762,9 +1761,9 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, if (!isa(GV) && !GV->hasExternalWeakLinkage() && !NullPointerIsDefined(nullptr /* F */, GV->getType()->getAddressSpace())) { - if (pred == ICmpInst::ICMP_EQ) + if (Predicate == ICmpInst::ICMP_EQ) return ConstantInt::getFalse(C1->getContext()); - else if (pred == ICmpInst::ICMP_NE) + else if (Predicate == ICmpInst::ICMP_NE) return ConstantInt::getTrue(C1->getContext()); } } @@ -1772,16 +1771,16 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, // The caller is expected to commute the operands if the constant expression // is C2. // C1 >= 0 --> true - if (pred == ICmpInst::ICMP_UGE) + if (Predicate == ICmpInst::ICMP_UGE) return Constant::getAllOnesValue(ResultTy); // C1 < 0 --> false - if (pred == ICmpInst::ICMP_ULT) + if (Predicate == ICmpInst::ICMP_ULT) return Constant::getNullValue(ResultTy); } // If the comparison is a comparison between two i1's, simplify it. if (C1->getType()->isIntegerTy(1)) { - switch(pred) { + switch (Predicate) { case ICmpInst::ICMP_EQ: if (isa(C2)) return ConstantExpr::getXor(C1, ConstantExpr::getNot(C2)); @@ -1796,12 +1795,10 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, if (isa(C1) && isa(C2)) { const APInt &V1 = cast(C1)->getValue(); const APInt &V2 = cast(C2)->getValue(); - return ConstantInt::get( - ResultTy, ICmpInst::compare(V1, V2, (ICmpInst::Predicate)pred)); + return ConstantInt::get(ResultTy, ICmpInst::compare(V1, V2, Predicate)); } else if (isa(C1) && isa(C2)) { const APFloat &C1V = cast(C1)->getValueAPF(); const APFloat &C2V = cast(C2)->getValueAPF(); - CmpInst::Predicate Predicate = CmpInst::Predicate(pred); return ConstantInt::get(ResultTy, FCmpInst::compare(C1V, C2V, Predicate)); } else if (auto *C1VTy = dyn_cast(C1->getType())) { @@ -1810,7 +1807,7 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, if (Constant *C2Splat = C2->getSplatValue()) return ConstantVector::getSplat( C1VTy->getElementCount(), - ConstantExpr::getCompare(pred, C1Splat, C2Splat)); + ConstantExpr::getCompare(Predicate, C1Splat, C2Splat)); // Do not iterate on scalable vector. The number of elements is unknown at // compile-time. @@ -1829,7 +1826,7 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, Constant *C2E = ConstantExpr::getExtractElement(C2, ConstantInt::get(Ty, I)); - ResElts.push_back(ConstantExpr::getCompare(pred, C1E, C2E)); + ResElts.push_back(ConstantExpr::getCompare(Predicate, C1E, C2E)); } return ConstantVector::get(ResElts); @@ -1854,46 +1851,52 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, case FCmpInst::BAD_FCMP_PREDICATE: break; // Couldn't determine anything about these constants. case FCmpInst::FCMP_OEQ: // We know that C1 == C2 - Result = (pred == FCmpInst::FCMP_UEQ || pred == FCmpInst::FCMP_OEQ || - pred == FCmpInst::FCMP_ULE || pred == FCmpInst::FCMP_OLE || - pred == FCmpInst::FCMP_UGE || pred == FCmpInst::FCMP_OGE); + Result = + (Predicate == FCmpInst::FCMP_UEQ || Predicate == FCmpInst::FCMP_OEQ || + Predicate == FCmpInst::FCMP_ULE || Predicate == FCmpInst::FCMP_OLE || + Predicate == FCmpInst::FCMP_UGE || Predicate == FCmpInst::FCMP_OGE); break; case FCmpInst::FCMP_OLT: // We know that C1 < C2 - Result = (pred == FCmpInst::FCMP_UNE || pred == FCmpInst::FCMP_ONE || - pred == FCmpInst::FCMP_ULT || pred == FCmpInst::FCMP_OLT || - pred == FCmpInst::FCMP_ULE || pred == FCmpInst::FCMP_OLE); + Result = + (Predicate == FCmpInst::FCMP_UNE || Predicate == FCmpInst::FCMP_ONE || + Predicate == FCmpInst::FCMP_ULT || Predicate == FCmpInst::FCMP_OLT || + Predicate == FCmpInst::FCMP_ULE || Predicate == FCmpInst::FCMP_OLE); break; case FCmpInst::FCMP_OGT: // We know that C1 > C2 - Result = (pred == FCmpInst::FCMP_UNE || pred == FCmpInst::FCMP_ONE || - pred == FCmpInst::FCMP_UGT || pred == FCmpInst::FCMP_OGT || - pred == FCmpInst::FCMP_UGE || pred == FCmpInst::FCMP_OGE); + Result = + (Predicate == FCmpInst::FCMP_UNE || Predicate == FCmpInst::FCMP_ONE || + Predicate == FCmpInst::FCMP_UGT || Predicate == FCmpInst::FCMP_OGT || + Predicate == FCmpInst::FCMP_UGE || Predicate == FCmpInst::FCMP_OGE); break; case FCmpInst::FCMP_OLE: // We know that C1 <= C2 // We can only partially decide this relation. - if (pred == FCmpInst::FCMP_UGT || pred == FCmpInst::FCMP_OGT) + if (Predicate == FCmpInst::FCMP_UGT || Predicate == FCmpInst::FCMP_OGT) Result = 0; - else if (pred == FCmpInst::FCMP_ULT || pred == FCmpInst::FCMP_OLT) + else if (Predicate == FCmpInst::FCMP_ULT || + Predicate == FCmpInst::FCMP_OLT) Result = 1; break; case FCmpInst::FCMP_OGE: // We known that C1 >= C2 // We can only partially decide this relation. - if (pred == FCmpInst::FCMP_ULT || pred == FCmpInst::FCMP_OLT) + if (Predicate == FCmpInst::FCMP_ULT || Predicate == FCmpInst::FCMP_OLT) Result = 0; - else if (pred == FCmpInst::FCMP_UGT || pred == FCmpInst::FCMP_OGT) + else if (Predicate == FCmpInst::FCMP_UGT || + Predicate == FCmpInst::FCMP_OGT) Result = 1; break; case FCmpInst::FCMP_ONE: // We know that C1 != C2 // We can only partially decide this relation. - if (pred == FCmpInst::FCMP_OEQ || pred == FCmpInst::FCMP_UEQ) + if (Predicate == FCmpInst::FCMP_OEQ || Predicate == FCmpInst::FCMP_UEQ) Result = 0; - else if (pred == FCmpInst::FCMP_ONE || pred == FCmpInst::FCMP_UNE) + else if (Predicate == FCmpInst::FCMP_ONE || + Predicate == FCmpInst::FCMP_UNE) Result = 1; break; case FCmpInst::FCMP_UEQ: // We know that C1 == C2 || isUnordered(C1, C2). // We can only partially decide this relation. - if (pred == FCmpInst::FCMP_ONE) + if (Predicate == FCmpInst::FCMP_ONE) Result = 0; - else if (pred == FCmpInst::FCMP_UEQ) + else if (Predicate == FCmpInst::FCMP_UEQ) Result = 1; break; } @@ -1905,67 +1908,84 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, } else { // Evaluate the relation between the two constants, per the predicate. int Result = -1; // -1 = unknown, 0 = known false, 1 = known true. - switch (evaluateICmpRelation(C1, C2, - CmpInst::isSigned((CmpInst::Predicate)pred))) { + switch (evaluateICmpRelation(C1, C2, CmpInst::isSigned(Predicate))) { default: llvm_unreachable("Unknown relational!"); case ICmpInst::BAD_ICMP_PREDICATE: break; // Couldn't determine anything about these constants. case ICmpInst::ICMP_EQ: // We know the constants are equal! // If we know the constants are equal, we can decide the result of this // computation precisely. - Result = ICmpInst::isTrueWhenEqual((ICmpInst::Predicate)pred); + Result = ICmpInst::isTrueWhenEqual(Predicate); break; case ICmpInst::ICMP_ULT: - switch (pred) { + switch (Predicate) { case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_NE: case ICmpInst::ICMP_ULE: Result = 1; break; case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_UGE: Result = 0; break; + default: + break; } break; case ICmpInst::ICMP_SLT: - switch (pred) { + switch (Predicate) { case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_NE: case ICmpInst::ICMP_SLE: Result = 1; break; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_SGE: Result = 0; break; + default: + break; } break; case ICmpInst::ICMP_UGT: - switch (pred) { + switch (Predicate) { case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_NE: case ICmpInst::ICMP_UGE: Result = 1; break; case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_ULE: Result = 0; break; + default: + break; } break; case ICmpInst::ICMP_SGT: - switch (pred) { + switch (Predicate) { case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_NE: case ICmpInst::ICMP_SGE: Result = 1; break; case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_SLE: Result = 0; break; + default: + break; } break; case ICmpInst::ICMP_ULE: - if (pred == ICmpInst::ICMP_UGT) Result = 0; - if (pred == ICmpInst::ICMP_ULT || pred == ICmpInst::ICMP_ULE) Result = 1; + if (Predicate == ICmpInst::ICMP_UGT) + Result = 0; + if (Predicate == ICmpInst::ICMP_ULT || Predicate == ICmpInst::ICMP_ULE) + Result = 1; break; case ICmpInst::ICMP_SLE: - if (pred == ICmpInst::ICMP_SGT) Result = 0; - if (pred == ICmpInst::ICMP_SLT || pred == ICmpInst::ICMP_SLE) Result = 1; + if (Predicate == ICmpInst::ICMP_SGT) + Result = 0; + if (Predicate == ICmpInst::ICMP_SLT || Predicate == ICmpInst::ICMP_SLE) + Result = 1; break; case ICmpInst::ICMP_UGE: - if (pred == ICmpInst::ICMP_ULT) Result = 0; - if (pred == ICmpInst::ICMP_UGT || pred == ICmpInst::ICMP_UGE) Result = 1; + if (Predicate == ICmpInst::ICMP_ULT) + Result = 0; + if (Predicate == ICmpInst::ICMP_UGT || Predicate == ICmpInst::ICMP_UGE) + Result = 1; break; case ICmpInst::ICMP_SGE: - if (pred == ICmpInst::ICMP_SLT) Result = 0; - if (pred == ICmpInst::ICMP_SGT || pred == ICmpInst::ICMP_SGE) Result = 1; + if (Predicate == ICmpInst::ICMP_SLT) + Result = 0; + if (Predicate == ICmpInst::ICMP_SGT || Predicate == ICmpInst::ICMP_SGE) + Result = 1; break; case ICmpInst::ICMP_NE: - if (pred == ICmpInst::ICMP_EQ) Result = 0; - if (pred == ICmpInst::ICMP_NE) Result = 1; + if (Predicate == ICmpInst::ICMP_EQ) + Result = 0; + if (Predicate == ICmpInst::ICMP_NE) + Result = 1; break; } @@ -1983,16 +2003,16 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, CE2->getType()->isVectorTy() == CE2Op0->getType()->isVectorTy() && !CE2Op0->getType()->isFPOrFPVectorTy()) { Constant *Inverse = ConstantExpr::getBitCast(C1, CE2Op0->getType()); - return ConstantExpr::getICmp(pred, Inverse, CE2Op0); + return ConstantExpr::getICmp(Predicate, Inverse, CE2Op0); } } // If the left hand side is an extension, try eliminating it. if (ConstantExpr *CE1 = dyn_cast(C1)) { if ((CE1->getOpcode() == Instruction::SExt && - ICmpInst::isSigned((ICmpInst::Predicate)pred)) || + ICmpInst::isSigned(Predicate)) || (CE1->getOpcode() == Instruction::ZExt && - !ICmpInst::isSigned((ICmpInst::Predicate)pred))){ + !ICmpInst::isSigned(Predicate))) { Constant *CE1Op0 = CE1->getOperand(0); Constant *CE1Inverse = ConstantExpr::getTrunc(CE1, CE1Op0->getType()); if (CE1Inverse == CE1Op0) { @@ -2000,7 +2020,7 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, Constant *C2Inverse = ConstantExpr::getTrunc(C2, CE1Op0->getType()); if (ConstantExpr::getCast(CE1->getOpcode(), C2Inverse, C2->getType()) == C2) - return ConstantExpr::getICmp(pred, CE1Inverse, C2Inverse); + return ConstantExpr::getICmp(Predicate, CE1Inverse, C2Inverse); } } } @@ -2010,8 +2030,8 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, // If C2 is a constant expr and C1 isn't, flip them around and fold the // other way if possible. // Also, if C1 is null and C2 isn't, flip them around. - pred = ICmpInst::getSwappedPredicate((ICmpInst::Predicate)pred); - return ConstantExpr::getICmp(pred, C2, C1); + Predicate = ICmpInst::getSwappedPredicate(Predicate); + return ConstantExpr::getICmp(Predicate, C2, C1); } } return nullptr; diff --git a/llvm/lib/IR/ConstantFold.h b/llvm/lib/IR/ConstantFold.h index 0cdd5cf3cbce67..1aa44f4d21e58c 100644 --- a/llvm/lib/IR/ConstantFold.h +++ b/llvm/lib/IR/ConstantFold.h @@ -19,6 +19,7 @@ #define LLVM_LIB_IR_CONSTANTFOLD_H #include "llvm/ADT/Optional.h" +#include "llvm/IR/InstrTypes.h" namespace llvm { template class ArrayRef; @@ -46,7 +47,7 @@ template class ArrayRef; Constant *ConstantFoldUnaryInstruction(unsigned Opcode, Constant *V); Constant *ConstantFoldBinaryInstruction(unsigned Opcode, Constant *V1, Constant *V2); - Constant *ConstantFoldCompareInstruction(unsigned short predicate, + Constant *ConstantFoldCompareInstruction(CmpInst::Predicate Predicate, Constant *C1, Constant *C2); Constant *ConstantFoldGetElementPtr(Type *Ty, Constant *C, bool InBounds, Optional InRangeIndex, diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp index 837be910f6d815..e753fc7a38710a 100644 --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -2546,11 +2546,11 @@ Constant *ConstantExpr::getGetElementPtr(Type *Ty, Constant *C, Constant *ConstantExpr::getICmp(unsigned short pred, Constant *LHS, Constant *RHS, bool OnlyIfReduced) { + auto Predicate = static_cast(pred); assert(LHS->getType() == RHS->getType()); - assert(CmpInst::isIntPredicate((CmpInst::Predicate)pred) && - "Invalid ICmp Predicate"); + assert(CmpInst::isIntPredicate(Predicate) && "Invalid ICmp Predicate"); - if (Constant *FC = ConstantFoldCompareInstruction(pred, LHS, RHS)) + if (Constant *FC = ConstantFoldCompareInstruction(Predicate, LHS, RHS)) return FC; // Fold a few common cases... if (OnlyIfReduced) @@ -2559,7 +2559,7 @@ Constant *ConstantExpr::getICmp(unsigned short pred, Constant *LHS, // Look up the constant in the table first to ensure uniqueness Constant *ArgVec[] = { LHS, RHS }; // Get the key type with both the opcode and predicate - const ConstantExprKeyType Key(Instruction::ICmp, ArgVec, pred); + const ConstantExprKeyType Key(Instruction::ICmp, ArgVec, Predicate); Type *ResultTy = Type::getInt1Ty(LHS->getContext()); if (VectorType *VT = dyn_cast(LHS->getType())) @@ -2571,11 +2571,11 @@ Constant *ConstantExpr::getICmp(unsigned short pred, Constant *LHS, Constant *ConstantExpr::getFCmp(unsigned short pred, Constant *LHS, Constant *RHS, bool OnlyIfReduced) { + auto Predicate = static_cast(pred); assert(LHS->getType() == RHS->getType()); - assert(CmpInst::isFPPredicate((CmpInst::Predicate)pred) && - "Invalid FCmp Predicate"); + assert(CmpInst::isFPPredicate(Predicate) && "Invalid FCmp Predicate"); - if (Constant *FC = ConstantFoldCompareInstruction(pred, LHS, RHS)) + if (Constant *FC = ConstantFoldCompareInstruction(Predicate, LHS, RHS)) return FC; // Fold a few common cases... if (OnlyIfReduced) @@ -2584,7 +2584,7 @@ Constant *ConstantExpr::getFCmp(unsigned short pred, Constant *LHS, // Look up the constant in the table first to ensure uniqueness Constant *ArgVec[] = { LHS, RHS }; // Get the key type with both the opcode and predicate - const ConstantExprKeyType Key(Instruction::FCmp, ArgVec, pred); + const ConstantExprKeyType Key(Instruction::FCmp, ArgVec, Predicate); Type *ResultTy = Type::getInt1Ty(LHS->getContext()); if (VectorType *VT = dyn_cast(LHS->getType()))