diff --git a/llvm/lib/CodeGen/ExpandFp.cpp b/llvm/lib/CodeGen/ExpandFp.cpp index 9cc6c6a706c58..291a6447fe36c 100644 --- a/llvm/lib/CodeGen/ExpandFp.cpp +++ b/llvm/lib/CodeGen/ExpandFp.cpp @@ -74,11 +74,64 @@ class FRemExpander { /// Constant 1 of type \p ExTy. Value *One; + /// The frem argument/return types that can be expanded by this class. + // TODO The expansion could work for other floating point types + // as well, but this would require additional testing. + static constexpr std::array ExpandableTypes{MVT::f16, MVT::f32, + MVT::f64}; + + /// Libcalls for frem instructions of the type at the corresponding + /// positions of ExpandableTypes. + static constexpr std::array FremLibcalls{ + RTLIB::REM_F32, RTLIB::REM_F32, RTLIB::REM_F64}; + + /// Return the Libcall for frem instructions of expandable type \p VT or + /// std::nullopt if \p VT is not expandable. + static std::optional getFremLibcallForType(EVT VT) { + MVT V = VT.getSimpleVT(); + for (unsigned I = 0; I < ExpandableTypes.size(); I++) + if (ExpandableTypes[I] == V) + return FremLibcalls[I]; + + return {}; + }; + public: static bool canExpandType(Type *Ty) { - // TODO The expansion should work for other floating point types - // as well, but this would require additional testing. - return Ty->isIEEELikeFPTy() && !Ty->isBFloatTy() && !Ty->isFP128Ty(); + EVT VT = EVT::getEVT(Ty); + assert(VT.isSimple() && "Can expand only simple types"); + + return find(ExpandableTypes, VT.getSimpleVT()) != ExpandableTypes.end(); + } + + /// Return true if the pass should expand a frem instruction of the + /// given \p VT for the target represented by \p TLI. Expansion + /// should happen if the legalization for the scalar type uses a + /// non-existing libcall. 
+ static bool shouldExpandFremType(const TargetLowering &TLI, EVT VT) { + assert(!VT.isVector() && "Cannot handle vector type; must scalarize first"); + + TargetLowering::LegalizeAction LA = TLI.getOperationAction(ISD::FREM, VT); + if (LA != TargetLowering::LegalizeAction::Expand) + return false; + + auto Libcall = getFremLibcallForType(VT); + return Libcall.has_value() && !TLI.getLibcallName(*Libcall); + } + + static bool shouldExpandFremType(const TargetLowering &TLI, Type *Ty) { + // Consider scalar type for simplicity. + // It is very unlikely that a vector type can be legalized without a libcall + // if the scalar type cannot. + return shouldExpandFremType(TLI, EVT::getEVT(Ty->getScalarType())); + } + + /// Return true if the pass should expand "frem" instructions of any expandable type for + /// the target represented by \p TLI. + static bool shouldExpandAnyFremType(const TargetLowering &TLI) { + return std::any_of( + ExpandableTypes.begin(), ExpandableTypes.end(), + [&](MVT V) { return shouldExpandFremType(TLI, EVT(V)); }); } static FRemExpander create(IRBuilder<> &B, Type *Ty) { @@ -959,36 +1012,6 @@ static void scalarize(Instruction *I, SmallVectorImpl &Replace) { I->eraseFromParent(); } -// This covers all floating point types; more than we need here. -// TODO Move somewhere else for general use? -/// Return the Libcall for a frem instruction of -/// type \p Ty. -static RTLIB::Libcall fremToLibcall(Type *Ty) { - assert(Ty->isFloatingPointTy()); - if (Ty->isFloatTy() || Ty->is16bitFPTy()) - return RTLIB::REM_F32; - if (Ty->isDoubleTy()) - return RTLIB::REM_F64; - if (Ty->isFP128Ty()) - return RTLIB::REM_F128; - if (Ty->isX86_FP80Ty()) - return RTLIB::REM_F80; - if (Ty->isPPC_FP128Ty()) - return RTLIB::REM_PPCF128; - - llvm_unreachable("Unknown floating point type"); -} - -/* Return true if, according to \p LibInfo, the target either directly - supports the frem instruction for the \p Ty, has a custom lowering, - or uses a libcall. 
*/ -static bool targetSupportsFrem(const TargetLowering &TLI, Type *Ty) { - if (!TLI.isOperationExpand(ISD::FREM, EVT::getEVT(Ty))) - return true; - - return TLI.getLibcallName(fremToLibcall(Ty->getScalarType())); -} - static bool runImpl(Function &F, const TargetLowering &TLI, AssumptionCache *AC) { SmallVector Replace; @@ -1000,19 +1023,25 @@ static bool runImpl(Function &F, const TargetLowering &TLI, if (ExpandFpConvertBits != llvm::IntegerType::MAX_INT_BITS) MaxLegalFpConvertBitWidth = ExpandFpConvertBits; - if (MaxLegalFpConvertBitWidth >= llvm::IntegerType::MAX_INT_BITS) + bool DisableExpandLargeFp = + MaxLegalFpConvertBitWidth >= llvm::IntegerType::MAX_INT_BITS; + bool DisableFrem = !FRemExpander::shouldExpandAnyFremType(TLI); + + if (DisableExpandLargeFp && DisableFrem) return false; for (auto &I : instructions(F)) { switch (I.getOpcode()) { case Instruction::FRem: { + if (DisableFrem) + continue; + Type *Ty = I.getType(); // TODO: This pass doesn't handle scalable vectors. if (Ty->isScalableTy()) continue; - if (targetSupportsFrem(TLI, Ty) || - !FRemExpander::canExpandType(Ty->getScalarType())) + if (!FRemExpander::shouldExpandFremType(TLI, Ty)) continue; Replace.push_back(&I); @@ -1022,6 +1051,9 @@ static bool runImpl(Function &F, const TargetLowering &TLI, } case Instruction::FPToUI: case Instruction::FPToSI: { + if (DisableExpandLargeFp) + continue; + // TODO: This pass doesn't handle scalable vectors. if (I.getOperand(0)->getType()->isScalableTy()) continue; @@ -1039,6 +1071,9 @@ static bool runImpl(Function &F, const TargetLowering &TLI, } case Instruction::UIToFP: case Instruction::SIToFP: { + if (DisableExpandLargeFp) + continue; + // TODO: This pass doesn't handle scalable vectors. if (I.getOperand(0)->getType()->isScalableTy()) continue;