diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 5b331e4444915..b024e8a68bd6e 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1718,16 +1718,6 @@ class SelectionDAG {
   /// the target's desired shift amount type.
   LLVM_ABI SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);

-  /// Expands a node with multiple results to an FP or vector libcall. The
-  /// libcall is expected to take all the operands of the \p Node followed by
-  /// output pointers for each of the results. \p CallRetResNo can be optionally
-  /// set to indicate that one of the results comes from the libcall's return
-  /// value.
-  LLVM_ABI bool
-  expandMultipleResultFPLibCall(RTLIB::Libcall LC, SDNode *Node,
-                                SmallVectorImpl<SDValue> &Results,
-                                std::optional<unsigned> CallRetResNo = {});
-
   /// Expand the specified \c ISD::VAARG node as the Legalize pass would.
   LLVM_ABI SDValue expandVAArg(SDNode *Node);
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 4d5d1fc7dfadc..cec7d09f494d6 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5757,6 +5757,16 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   /// consisting of zext/sext, extract_subvector, mul and add operations.
   SDValue expandPartialReduceMLA(SDNode *Node, SelectionDAG &DAG) const;

+  /// Expands a node with multiple results to an FP or vector libcall. The
+  /// libcall is expected to take all the operands of the \p Node followed by
+  /// output pointers for each of the results. \p CallRetResNo can be optionally
+  /// set to indicate that one of the results comes from the libcall's return
+  /// value.
+  bool expandMultipleResultFPLibCall(
+      SelectionDAG &DAG, RTLIB::Libcall LC, SDNode *Node,
+      SmallVectorImpl<SDValue> &Results,
+      std::optional<unsigned> CallRetResNo = {}) const;
+
   /// Legalize a SETCC or VP_SETCC with given LHS and RHS and condition code CC
   /// on the current target. A VP_SETCC will additionally be given a Mask
   /// and/or EVL not equal to SDValue().
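For illustration only (not part of the patch): the calling convention described by the doc comment mirrors the familiar C math routines. The node's operands are passed first, then one output pointer per result, and `CallRetResNo` marks a result that comes back in the return value instead of through a pointer. A minimal sketch of that shape using the standard `<cmath>` entry points, with made-up variable names:

```cpp
// Illustration of the libcall shapes expandMultipleResultFPLibCall targets,
// shown with standard <cmath> functions rather than LLVM code.
#include <cmath>
#include <cstdio>

int main() {
  float X = 6.25f;

  // FFREXP-style (CallRetResNo = 0): result 0 (the fraction) is the return
  // value, result 1 (the exponent) is written through an output pointer.
  int Exp;
  float Frac = std::frexp(X, &Exp);

  // FMODF-style (CallRetResNo = 0): result 0 (the fractional part) is the
  // return value, result 1 (the integral part) goes through a pointer.
  float IntPart;
  float FracPart = std::modf(X, &IntPart);

  // FSINCOS-style (no CallRetResNo): every result goes through an output
  // pointer, as in the (non-standard) void sincosf(float, float*, float*).
  std::printf("frexp: %g * 2^%d, modf: %g + %g\n", Frac, Exp, FracPart, IntPart);
  return 0;
}
```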
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 3ed84af6a8717..99d14a60c6ed1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -4842,7 +4842,7 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
     RTLIB::Libcall LC = Node->getOpcode() == ISD::FSINCOS
                             ? RTLIB::getSINCOS(VT)
                             : RTLIB::getSINCOSPI(VT);
-    bool Expanded = DAG.expandMultipleResultFPLibCall(LC, Node, Results);
+    bool Expanded = TLI.expandMultipleResultFPLibCall(DAG, LC, Node, Results);
     if (!Expanded) {
       DAG.getContext()->emitError(Twine("no libcall available for ") +
                                   Node->getOperationName(&DAG));
@@ -4940,7 +4940,7 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
     EVT VT = Node->getValueType(0);
     RTLIB::Libcall LC = Node->getOpcode() == ISD::FMODF ? RTLIB::getMODF(VT)
                                                          : RTLIB::getFREXP(VT);
-    bool Expanded = DAG.expandMultipleResultFPLibCall(LC, Node, Results,
+    bool Expanded = TLI.expandMultipleResultFPLibCall(DAG, LC, Node, Results,
                                                       /*CallRetResNo=*/0);
     if (!Expanded)
       llvm_unreachable("Expected scalar FFREXP/FMODF to expand to libcall!");
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index 58983cb57d7f6..383a025a4d916 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -1726,7 +1726,7 @@ void DAGTypeLegalizer::ExpandFloatRes_UnaryWithTwoFPResults(
     SDNode *N, RTLIB::Libcall LC, std::optional<unsigned> CallRetResNo) {
   assert(!N->isStrictFPOpcode() && "strictfp not implemented");
   SmallVector<SDValue> Results;
-  DAG.expandMultipleResultFPLibCall(LC, N, Results, CallRetResNo);
+  TLI.expandMultipleResultFPLibCall(DAG, LC, N, Results, CallRetResNo);
   for (auto [ResNo, Res] : enumerate(Results)) {
     SDValue Lo, Hi;
     GetPairElements(Res, Lo, Hi);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index c55e55df373e9..7d979caa8bf82 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -1275,7 +1275,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
                               ? RTLIB::getSINCOS(VT)
                               : RTLIB::getSINCOSPI(VT);
     if (LC != RTLIB::UNKNOWN_LIBCALL &&
-        DAG.expandMultipleResultFPLibCall(LC, Node, Results))
+        TLI.expandMultipleResultFPLibCall(DAG, LC, Node, Results))
       return;

     // TODO: Try to see if there's a narrower call available to use before
@@ -1286,7 +1286,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
     EVT VT = Node->getValueType(0);
     RTLIB::Libcall LC = RTLIB::getMODF(VT);
     if (LC != RTLIB::UNKNOWN_LIBCALL &&
-        DAG.expandMultipleResultFPLibCall(LC, Node, Results,
+        TLI.expandMultipleResultFPLibCall(DAG, LC, Node, Results,
                                           /*CallRetResNo=*/0))
       return;
     break;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index f05266967fb68..363c71d84694f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2467,167 +2467,6 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
   return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
 }

-/// Given a store node \p StoreNode, return true if it is safe to fold that node
-/// into \p FPNode, which expands to a library call with output pointers.
-static bool canFoldStoreIntoLibCallOutputPointers(StoreSDNode *StoreNode,
-                                                  SDNode *FPNode) {
-  SmallVector<const SDNode *> Worklist;
-  SmallVector<const SDNode *> DeferredNodes;
-  SmallPtrSet<const SDNode *, 16> Visited;
-
-  // Skip FPNode use by StoreNode (that's the use we want to fold into FPNode).
-  for (SDValue Op : StoreNode->ops())
-    if (Op.getNode() != FPNode)
-      Worklist.push_back(Op.getNode());
-
-  unsigned MaxSteps = SelectionDAG::getHasPredecessorMaxSteps();
-  while (!Worklist.empty()) {
-    const SDNode *Node = Worklist.pop_back_val();
-    auto [_, Inserted] = Visited.insert(Node);
-    if (!Inserted)
-      continue;
-
-    if (MaxSteps > 0 && Visited.size() >= MaxSteps)
-      return false;
-
-    // Reached the FPNode (would result in a cycle).
-    // OR Reached CALLSEQ_START (would result in nested call sequences).
-    if (Node == FPNode || Node->getOpcode() == ISD::CALLSEQ_START)
-      return false;
-
-    if (Node->getOpcode() == ISD::CALLSEQ_END) {
-      // Defer looking into call sequences (so we can check we're outside one).
-      // We still need to look through these for the predecessor check.
-      DeferredNodes.push_back(Node);
-      continue;
-    }
-
-    for (SDValue Op : Node->ops())
-      Worklist.push_back(Op.getNode());
-  }
-
-  // True if we're outside a call sequence and don't have the FPNode as a
-  // predecessor. No cycles or nested call sequences possible.
-  return !SDNode::hasPredecessorHelper(FPNode, Visited, DeferredNodes,
-                                       MaxSteps);
-}
-
-bool SelectionDAG::expandMultipleResultFPLibCall(
-    RTLIB::Libcall LC, SDNode *Node, SmallVectorImpl<SDValue> &Results,
-    std::optional<unsigned> CallRetResNo) {
-  if (LC == RTLIB::UNKNOWN_LIBCALL)
-    return false;
-
-  RTLIB::LibcallImpl LibcallImpl = TLI->getLibcallImpl(LC);
-  if (LibcallImpl == RTLIB::Unsupported)
-    return false;
-
-  LLVMContext &Ctx = *getContext();
-  EVT VT = Node->getValueType(0);
-  unsigned NumResults = Node->getNumValues();
-
-  // Find users of the node that store the results (and share input chains). The
-  // destination pointers can be used instead of creating stack allocations.
-  SDValue StoresInChain;
-  SmallVector<StoreSDNode *, 2> ResultStores(NumResults);
-  for (SDNode *User : Node->users()) {
-    if (!ISD::isNormalStore(User))
-      continue;
-    auto *ST = cast<StoreSDNode>(User);
-    SDValue StoreValue = ST->getValue();
-    unsigned ResNo = StoreValue.getResNo();
-    // Ensure the store corresponds to an output pointer.
-    if (CallRetResNo == ResNo)
-      continue;
-    // Ensure the store to the default address space and not atomic or volatile.
-    if (!ST->isSimple() || ST->getAddressSpace() != 0)
-      continue;
-    // Ensure all store chains are the same (so they don't alias).
-    if (StoresInChain && ST->getChain() != StoresInChain)
-      continue;
-    // Ensure the store is properly aligned.
-    Type *StoreType = StoreValue.getValueType().getTypeForEVT(Ctx);
-    if (ST->getAlign() <
-        getDataLayout().getABITypeAlign(StoreType->getScalarType()))
-      continue;
-    // Avoid:
-    // 1. Creating cyclic dependencies.
-    // 2. Expanding the node to a call within a call sequence.
-    if (!canFoldStoreIntoLibCallOutputPointers(ST, Node))
-      continue;
-    ResultStores[ResNo] = ST;
-    StoresInChain = ST->getChain();
-  }
-
-  TargetLowering::ArgListTy Args;
-
-  // Pass the arguments.
-  for (const SDValue &Op : Node->op_values()) {
-    EVT ArgVT = Op.getValueType();
-    Type *ArgTy = ArgVT.getTypeForEVT(Ctx);
-    Args.emplace_back(Op, ArgTy);
-  }
-
-  // Pass the output pointers.
-  SmallVector<SDValue, 2> ResultPtrs(NumResults);
-  Type *PointerTy = PointerType::getUnqual(Ctx);
-  for (auto [ResNo, ST] : llvm::enumerate(ResultStores)) {
-    if (ResNo == CallRetResNo)
-      continue;
-    EVT ResVT = Node->getValueType(ResNo);
-    SDValue ResultPtr = ST ? ST->getBasePtr() : CreateStackTemporary(ResVT);
-    ResultPtrs[ResNo] = ResultPtr;
-    Args.emplace_back(ResultPtr, PointerTy);
-  }
-
-  SDLoc DL(Node);
-
-  if (RTLIB::RuntimeLibcallsInfo::hasVectorMaskArgument(LibcallImpl)) {
-    // Pass the vector mask (if required).
-    EVT MaskVT = TLI->getSetCCResultType(getDataLayout(), Ctx, VT);
-    SDValue Mask = getBoolConstant(true, DL, MaskVT, VT);
-    Args.emplace_back(Mask, MaskVT.getTypeForEVT(Ctx));
-  }
-
-  Type *RetType = CallRetResNo.has_value()
-                      ? Node->getValueType(*CallRetResNo).getTypeForEVT(Ctx)
-                      : Type::getVoidTy(Ctx);
-  SDValue InChain = StoresInChain ? StoresInChain : getEntryNode();
-  SDValue Callee =
-      getExternalSymbol(TLI->getLibcallImplName(LibcallImpl).data(),
-                        TLI->getPointerTy(getDataLayout()));
-  TargetLowering::CallLoweringInfo CLI(*this);
-  CLI.setDebugLoc(DL).setChain(InChain).setLibCallee(
-      TLI->getLibcallImplCallingConv(LibcallImpl), RetType, Callee,
-      std::move(Args));
-
-  auto [Call, CallChain] = TLI->LowerCallTo(CLI);
-
-  for (auto [ResNo, ResultPtr] : llvm::enumerate(ResultPtrs)) {
-    if (ResNo == CallRetResNo) {
-      Results.push_back(Call);
-      continue;
-    }
-    MachinePointerInfo PtrInfo;
-    SDValue LoadResult =
-        getLoad(Node->getValueType(ResNo), DL, CallChain, ResultPtr, PtrInfo);
-    SDValue OutChain = LoadResult.getValue(1);
-
-    if (StoreSDNode *ST = ResultStores[ResNo]) {
-      // Replace store with the library call.
-      ReplaceAllUsesOfValueWith(SDValue(ST, 0), OutChain);
-      PtrInfo = ST->getPointerInfo();
-    } else {
-      PtrInfo = MachinePointerInfo::getFixedStack(
-          getMachineFunction(), cast<FrameIndexSDNode>(ResultPtr)->getIndex());
-    }
-
-    Results.push_back(LoadResult);
-  }
-
-  return true;
-}
-
 SDValue SelectionDAG::expandVAArg(SDNode *Node) {
   SDLoc dl(Node);
   const TargetLowering &TLI = getTargetLoweringInfo();
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index b51d6649af2ec..bb64f4ee70280 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -12126,6 +12126,167 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
   return Subvectors[0];
 }

+/// Given a store node \p StoreNode, return true if it is safe to fold that node
+/// into \p FPNode, which expands to a library call with output pointers.
+static bool canFoldStoreIntoLibCallOutputPointers(StoreSDNode *StoreNode,
+                                                  SDNode *FPNode) {
+  SmallVector<const SDNode *> Worklist;
+  SmallVector<const SDNode *> DeferredNodes;
+  SmallPtrSet<const SDNode *, 16> Visited;
+
+  // Skip FPNode use by StoreNode (that's the use we want to fold into FPNode).
+  for (SDValue Op : StoreNode->ops())
+    if (Op.getNode() != FPNode)
+      Worklist.push_back(Op.getNode());
+
+  unsigned MaxSteps = SelectionDAG::getHasPredecessorMaxSteps();
+  while (!Worklist.empty()) {
+    const SDNode *Node = Worklist.pop_back_val();
+    auto [_, Inserted] = Visited.insert(Node);
+    if (!Inserted)
+      continue;
+
+    if (MaxSteps > 0 && Visited.size() >= MaxSteps)
+      return false;
+
+    // Reached the FPNode (would result in a cycle).
+    // OR Reached CALLSEQ_START (would result in nested call sequences).
+    if (Node == FPNode || Node->getOpcode() == ISD::CALLSEQ_START)
+      return false;
+
+    if (Node->getOpcode() == ISD::CALLSEQ_END) {
+      // Defer looking into call sequences (so we can check we're outside one).
+      // We still need to look through these for the predecessor check.
+      DeferredNodes.push_back(Node);
+      continue;
+    }
+
+    for (SDValue Op : Node->ops())
+      Worklist.push_back(Op.getNode());
+  }
+
+  // True if we're outside a call sequence and don't have the FPNode as a
+  // predecessor. No cycles or nested call sequences possible.
+  return !SDNode::hasPredecessorHelper(FPNode, Visited, DeferredNodes,
+                                       MaxSteps);
+}
+
+bool TargetLowering::expandMultipleResultFPLibCall(
+    SelectionDAG &DAG, RTLIB::Libcall LC, SDNode *Node,
+    SmallVectorImpl<SDValue> &Results,
+    std::optional<unsigned> CallRetResNo) const {
+  if (LC == RTLIB::UNKNOWN_LIBCALL)
+    return false;
+
+  RTLIB::LibcallImpl LibcallImpl = getLibcallImpl(LC);
+  if (LibcallImpl == RTLIB::Unsupported)
+    return false;
+
+  LLVMContext &Ctx = *DAG.getContext();
+  EVT VT = Node->getValueType(0);
+  unsigned NumResults = Node->getNumValues();
+
+  // Find users of the node that store the results (and share input chains). The
+  // destination pointers can be used instead of creating stack allocations.
+  SDValue StoresInChain;
+  SmallVector<StoreSDNode *, 2> ResultStores(NumResults);
+  for (SDNode *User : Node->users()) {
+    if (!ISD::isNormalStore(User))
+      continue;
+    auto *ST = cast<StoreSDNode>(User);
+    SDValue StoreValue = ST->getValue();
+    unsigned ResNo = StoreValue.getResNo();
+    // Ensure the store corresponds to an output pointer.
+    if (CallRetResNo == ResNo)
+      continue;
+    // Ensure the store to the default address space and not atomic or volatile.
+    if (!ST->isSimple() || ST->getAddressSpace() != 0)
+      continue;
+    // Ensure all store chains are the same (so they don't alias).
+    if (StoresInChain && ST->getChain() != StoresInChain)
+      continue;
+    // Ensure the store is properly aligned.
+    Type *StoreType = StoreValue.getValueType().getTypeForEVT(Ctx);
+    if (ST->getAlign() <
+        DAG.getDataLayout().getABITypeAlign(StoreType->getScalarType()))
+      continue;
+    // Avoid:
+    // 1. Creating cyclic dependencies.
+    // 2. Expanding the node to a call within a call sequence.
+    if (!canFoldStoreIntoLibCallOutputPointers(ST, Node))
+      continue;
+    ResultStores[ResNo] = ST;
+    StoresInChain = ST->getChain();
+  }
+
+  ArgListTy Args;
+
+  // Pass the arguments.
+  for (const SDValue &Op : Node->op_values()) {
+    EVT ArgVT = Op.getValueType();
+    Type *ArgTy = ArgVT.getTypeForEVT(Ctx);
+    Args.emplace_back(Op, ArgTy);
+  }
+
+  // Pass the output pointers.
+  SmallVector<SDValue, 2> ResultPtrs(NumResults);
+  Type *PointerTy = PointerType::getUnqual(Ctx);
+  for (auto [ResNo, ST] : llvm::enumerate(ResultStores)) {
+    if (ResNo == CallRetResNo)
+      continue;
+    EVT ResVT = Node->getValueType(ResNo);
+    SDValue ResultPtr = ST ? ST->getBasePtr() : DAG.CreateStackTemporary(ResVT);
+    ResultPtrs[ResNo] = ResultPtr;
+    Args.emplace_back(ResultPtr, PointerTy);
+  }
+
+  SDLoc DL(Node);
+
+  if (RTLIB::RuntimeLibcallsInfo::hasVectorMaskArgument(LibcallImpl)) {
+    // Pass the vector mask (if required).
+    EVT MaskVT = getSetCCResultType(DAG.getDataLayout(), Ctx, VT);
+    SDValue Mask = DAG.getBoolConstant(true, DL, MaskVT, VT);
+    Args.emplace_back(Mask, MaskVT.getTypeForEVT(Ctx));
+  }
+
+  Type *RetType = CallRetResNo.has_value()
+                      ? Node->getValueType(*CallRetResNo).getTypeForEVT(Ctx)
+                      : Type::getVoidTy(Ctx);
+  SDValue InChain = StoresInChain ? StoresInChain : DAG.getEntryNode();
+  SDValue Callee = DAG.getExternalSymbol(getLibcallImplName(LibcallImpl).data(),
+                                         getPointerTy(DAG.getDataLayout()));
+  TargetLowering::CallLoweringInfo CLI(DAG);
+  CLI.setDebugLoc(DL).setChain(InChain).setLibCallee(
+      getLibcallImplCallingConv(LibcallImpl), RetType, Callee, std::move(Args));
+
+  auto [Call, CallChain] = LowerCallTo(CLI);
+
+  for (auto [ResNo, ResultPtr] : llvm::enumerate(ResultPtrs)) {
+    if (ResNo == CallRetResNo) {
+      Results.push_back(Call);
+      continue;
+    }
+    MachinePointerInfo PtrInfo;
+    SDValue LoadResult = DAG.getLoad(Node->getValueType(ResNo), DL, CallChain,
+                                     ResultPtr, PtrInfo);
+    SDValue OutChain = LoadResult.getValue(1);
+
+    if (StoreSDNode *ST = ResultStores[ResNo]) {
+      // Replace store with the library call.
+      DAG.ReplaceAllUsesOfValueWith(SDValue(ST, 0), OutChain);
+      PtrInfo = ST->getPointerInfo();
+    } else {
+      PtrInfo = MachinePointerInfo::getFixedStack(
+          DAG.getMachineFunction(),
+          cast<FrameIndexSDNode>(ResultPtr)->getIndex());
+    }
+
+    Results.push_back(LoadResult);
+  }
+
+  return true;
+}
+
 bool TargetLowering::LegalizeSetCCCondCode(SelectionDAG &DAG, EVT VT,
                                            SDValue &LHS, SDValue &RHS,
                                            SDValue &CC, SDValue Mask,
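The moved implementation keeps the store-folding behaviour of the SelectionDAG version: when every result of the node is only consumed by simple, suitably aligned stores on a shared chain, the stores' destination pointers are passed to the libcall directly and the stores are replaced by the call, with stack temporaries used only for the remaining results; canFoldStoreIntoLibCallOutputPointers rejects the fold if it would create a cycle through the FP node or nest the call inside another call sequence. As a hedged, source-level illustration (not part of the patch), the classic beneficiary is a sincos-style helper like the one below; whether ISD::FSINCOS is actually formed depends on the target and optimization level.

```cpp
// Illustration only: with optimizations enabled and a target that combines
// the two calls into ISD::FSINCOS, SinOut/CosOut can be handed straight to a
// sincosf-style libcall as its output pointers, so the two stores are folded
// into the call and no stack temporaries are needed.
#include <cmath>

void sin_cos(float X, float *SinOut, float *CosOut) {
  *SinOut = std::sin(X);
  *CosOut = std::cos(X);
}
```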