diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index a70c9c02d3a83..ede9ddefec4a8 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -3576,6 +3576,10 @@ class LLVM_ABI TargetLoweringBase { return nullptr; } + const RTLIB::RuntimeLibcallsInfo &getRuntimeLibcallsInfo() const { + return Libcalls; + } + void setLibcallImpl(RTLIB::Libcall Call, RTLIB::LibcallImpl Impl) { Libcalls.setLibcallImpl(Call, Impl); } diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 018ef31a7b5b8..e0ba20df9eb28 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -9002,12 +9002,12 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI, } static SMECallAttrs -getSMECallAttrs(const Function &Caller, const AArch64TargetLowering &TLI, +getSMECallAttrs(const Function &Caller, const RTLIB::RuntimeLibcallsInfo &RTLCI, const TargetLowering::CallLoweringInfo &CLI) { if (CLI.CB) - return SMECallAttrs(*CLI.CB, &TLI); + return SMECallAttrs(*CLI.CB, &RTLCI); if (auto *ES = dyn_cast(CLI.Callee)) - return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol(), TLI)); + return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol(), RTLCI)); return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(SMEAttrs::Normal)); } @@ -9029,7 +9029,8 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization( // SME Streaming functions are not eligible for TCO as they may require // the streaming mode or ZA to be restored after returning from the call. - SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, *this, CLI); + SMECallAttrs CallAttrs = + getSMECallAttrs(CallerF, getRuntimeLibcallsInfo(), CLI); if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() || CallAttrs.requiresPreservingAllZAState() || CallAttrs.caller().hasStreamingBody()) @@ -9454,7 +9455,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, } // Determine whether we need any streaming mode changes. - SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI); + SMECallAttrs CallAttrs = + getSMECallAttrs(MF.getFunction(), getRuntimeLibcallsInfo(), CLI); std::optional ZAMarkerNode; bool UseNewSMEABILowering = getTM().useNewSMEABILowering(); @@ -29818,7 +29820,7 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const { // Checks to allow the use of SME instructions if (auto *Base = dyn_cast(&Inst)) { - auto CallAttrs = SMECallAttrs(*Base, this); + auto CallAttrs = SMECallAttrs(*Base, &getRuntimeLibcallsInfo()); if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() || CallAttrs.requiresPreservingZT0() || CallAttrs.requiresPreservingAllZAState()) diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index d50af118d7320..fede586cf35bc 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -224,7 +224,8 @@ static cl::opt EnableScalableAutovecInStreamingMode( static bool isSMEABIRoutineCall(const CallInst &CI, const AArch64TargetLowering &TLI) { const auto *F = CI.getCalledFunction(); - return F && SMEAttrs(F->getName(), TLI).isSMEABIRoutine(); + return F && + SMEAttrs(F->getName(), TLI.getRuntimeLibcallsInfo()).isSMEABIRoutine(); } /// Returns true if the function has explicit operations that can only be @@ -355,7 +356,7 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call, // change only once and avoid inlining of G into F. SMEAttrs FAttrs(*F); - SMECallAttrs CallAttrs(Call, getTLI()); + SMECallAttrs CallAttrs(Call, &getTLI()->getRuntimeLibcallsInfo()); if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) { if (F == Call.getCaller()) // (1) diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp index d71f7280597b9..92e09eed92a6b 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp @@ -75,8 +75,8 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) { } void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName, - const AArch64TargetLowering &TLI) { - RTLIB::LibcallImpl Impl = TLI.getSupportedLibcallImpl(FuncName); + const RTLIB::RuntimeLibcallsInfo &RTLCI) { + RTLIB::LibcallImpl Impl = RTLCI.getSupportedLibcallImpl(FuncName); if (Impl == RTLIB::Unsupported) return; unsigned KnownAttrs = SMEAttrs::Normal; @@ -124,11 +124,12 @@ bool SMECallAttrs::requiresSMChange() const { return true; } -SMECallAttrs::SMECallAttrs(const CallBase &CB, const AArch64TargetLowering *TLI) +SMECallAttrs::SMECallAttrs(const CallBase &CB, + const RTLIB::RuntimeLibcallsInfo *RTLCI) : CallerFn(*CB.getFunction()), CalledFn(SMEAttrs::Normal), Callsite(CB.getAttributes()), IsIndirect(CB.isIndirectCall()) { if (auto *CalledFunction = CB.getCalledFunction()) - CalledFn = SMEAttrs(*CalledFunction, TLI); + CalledFn = SMEAttrs(*CalledFunction, RTLCI); // An `invoke` of an agnostic ZA function may not return normally (it may // resume in an exception block). In this case, it acts like a private ZA diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h index d26e3cd3a9f76..28c397e221fdc 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h @@ -12,8 +12,9 @@ #include "llvm/IR/Function.h" namespace llvm { - -class AArch64TargetLowering; +namespace RTLIB { +struct RuntimeLibcallsInfo; +} class Function; class CallBase; @@ -52,14 +53,14 @@ class SMEAttrs { SMEAttrs() = default; SMEAttrs(unsigned Mask) { set(Mask); } - SMEAttrs(const Function &F, const AArch64TargetLowering *TLI = nullptr) + SMEAttrs(const Function &F, const RTLIB::RuntimeLibcallsInfo *RTLCI = nullptr) : SMEAttrs(F.getAttributes()) { - if (TLI) - addKnownFunctionAttrs(F.getName(), *TLI); + if (RTLCI) + addKnownFunctionAttrs(F.getName(), *RTLCI); } SMEAttrs(const AttributeList &L); - SMEAttrs(StringRef FuncName, const AArch64TargetLowering &TLI) { - addKnownFunctionAttrs(FuncName, TLI); + SMEAttrs(StringRef FuncName, const RTLIB::RuntimeLibcallsInfo &RTLCI) { + addKnownFunctionAttrs(FuncName, RTLCI); }; void set(unsigned M, bool Enable = true) { @@ -157,7 +158,7 @@ class SMEAttrs { private: void addKnownFunctionAttrs(StringRef FuncName, - const AArch64TargetLowering &TLI); + const RTLIB::RuntimeLibcallsInfo &RTLCI); void validate() const; }; @@ -175,7 +176,7 @@ class SMECallAttrs { SMEAttrs Callsite = SMEAttrs::Normal) : CallerFn(Caller), CalledFn(Callee), Callsite(Callsite) {} - SMECallAttrs(const CallBase &CB, const AArch64TargetLowering *TLI); + SMECallAttrs(const CallBase &CB, const RTLIB::RuntimeLibcallsInfo *RTLCI); SMEAttrs &caller() { return CallerFn; } SMEAttrs &callee() { return IsIndirect ? Callsite : CalledFn; }