diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 5ffaf2c49b4c0..b90b7f50856b0 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -9312,6 +9312,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, std::optional ZAMarkerNode; bool UseNewSMEABILowering = getTM().useNewSMEABILowering(); + if (UseNewSMEABILowering) { if (CallAttrs.requiresLazySave() || CallAttrs.requiresPreservingAllZAState()) diff --git a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp index c39a5cc2fcb16..cced0faa28889 100644 --- a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp +++ b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp @@ -110,6 +110,71 @@ struct PhysRegSave { Register X0Save = AArch64::NoRegister; }; +/// Contains the needed ZA state (and live registers) at an instruction. That is +/// the state ZA must be in _before_ "InsertPt". +struct InstInfo { + ZAState NeededState{ZAState::ANY}; + MachineBasicBlock::iterator InsertPt; + LiveRegs PhysLiveRegs = LiveRegs::None; +}; + +/// Contains the needed ZA state for each instruction in a block. Instructions +/// that do not require a ZA state are not recorded. +struct BlockInfo { + ZAState FixedEntryState{ZAState::ANY}; + SmallVector Insts; + LiveRegs PhysLiveRegsAtEntry = LiveRegs::None; + LiveRegs PhysLiveRegsAtExit = LiveRegs::None; +}; + +/// Contains the needed ZA state information for all blocks within a function. +struct FunctionInfo { + SmallVector Blocks; + std::optional AfterSMEProloguePt; + LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None; +}; + +/// State/helpers that are only needed when emitting code to handle +/// saving/restoring ZA. +class EmitContext { +public: + EmitContext() = default; + + /// Get or create a TPIDR2 block in \p MF. 
+ int getTPIDR2Block(MachineFunction &MF) { + if (TPIDR2BlockFI) + return *TPIDR2BlockFI; + MachineFrameInfo &MFI = MF.getFrameInfo(); + TPIDR2BlockFI = MFI.CreateStackObject(16, Align(16), false); + return *TPIDR2BlockFI; + } + + /// Get or create the agnostic ZA buffer pointer in \p MF. + Register getAgnosticZABufferPtr(MachineFunction &MF) { + if (AgnosticZABufferPtr != AArch64::NoRegister) + return AgnosticZABufferPtr; + Register BufferPtr = + MF.getInfo()->getEarlyAllocSMESaveBuffer(); + AgnosticZABufferPtr = + BufferPtr != AArch64::NoRegister + ? BufferPtr + : MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass); + return AgnosticZABufferPtr; + } + + /// Returns true if the function must allocate a ZA save buffer on entry. This + /// will be the case if, at any point, a ZA save or restore was emitted. + bool needsSaveBuffer() const { + assert(!(TPIDR2BlockFI && AgnosticZABufferPtr) && + "Cannot have both a TPIDR2 block and agnostic ZA buffer"); + return TPIDR2BlockFI || AgnosticZABufferPtr != AArch64::NoRegister; + } + +private: + std::optional TPIDR2BlockFI; + Register AgnosticZABufferPtr = AArch64::NoRegister; +}; + static bool isLegalEdgeBundleZAState(ZAState State) { switch (State) { case ZAState::ACTIVE: @@ -119,9 +184,6 @@ static bool isLegalEdgeBundleZAState(ZAState State) { return false; } } -struct TPIDR2State { - int FrameIndex = -1; -}; StringRef getZAStateString(ZAState State) { #define MAKE_CASE(V) \ @@ -192,25 +254,28 @@ struct MachineSMEABI : public MachineFunctionPass { /// Collects the needed ZA state (and live registers) before each instruction /// within the machine function. - void collectNeededZAStates(SMEAttrs); + FunctionInfo collectNeededZAStates(SMEAttrs SMEFnAttrs); /// Assigns each edge bundle a ZA state based on the needed states of blocks /// that have incoming or outgoing edges in that bundle. 
- void assignBundleZAStates(); + SmallVector assignBundleZAStates(const EdgeBundles &Bundles, + const FunctionInfo &FnInfo); /// Inserts code to handle changes between ZA states within the function. /// E.g., ACTIVE -> LOCAL_SAVED will insert code required to save ZA. - void insertStateChanges(); + void insertStateChanges(EmitContext &, const FunctionInfo &FnInfo, + const EdgeBundles &Bundles, + ArrayRef BundleStates); // Emission routines for private and shared ZA functions (using lazy saves). void emitNewZAPrologue(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI); - void emitRestoreLazySave(MachineBasicBlock &MBB, + void emitRestoreLazySave(EmitContext &, MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs); - void emitSetupLazySave(MachineBasicBlock &MBB, + void emitSetupLazySave(EmitContext &, MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI); - void emitAllocateLazySaveBuffer(MachineBasicBlock &MBB, + void emitAllocateLazySaveBuffer(EmitContext &, MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI); void emitZAOff(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, bool ClearTPIDR2); @@ -222,78 +287,49 @@ struct MachineSMEABI : public MachineFunctionPass { // Emit a "full" ZA save or restore. It is "full" in the sense that this // function will emit a call to __arm_sme_save or __arm_sme_restore, which // handles saving and restoring both ZA and ZT0. 
- void emitFullZASaveRestore(MachineBasicBlock &MBB, + void emitFullZASaveRestore(EmitContext &, MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs, bool IsSave); - void emitAllocateFullZASaveBuffer(MachineBasicBlock &MBB, + void emitAllocateFullZASaveBuffer(EmitContext &, MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs); - void emitStateChange(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, - ZAState From, ZAState To, LiveRegs PhysLiveRegs); + void emitStateChange(EmitContext &, MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI, ZAState From, + ZAState To, LiveRegs PhysLiveRegs); // Helpers for switching between lazy/full ZA save/restore routines. - void emitZASave(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, - LiveRegs PhysLiveRegs) { + void emitZASave(EmitContext &Context, MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) { if (AFI->getSMEFnAttrs().hasAgnosticZAInterface()) - return emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/true); - return emitSetupLazySave(MBB, MBBI); + return emitFullZASaveRestore(Context, MBB, MBBI, PhysLiveRegs, + /*IsSave=*/true); + return emitSetupLazySave(Context, MBB, MBBI); } - void emitZARestore(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, - LiveRegs PhysLiveRegs) { + void emitZARestore(EmitContext &Context, MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) { if (AFI->getSMEFnAttrs().hasAgnosticZAInterface()) - return emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/false); - return emitRestoreLazySave(MBB, MBBI, PhysLiveRegs); + return emitFullZASaveRestore(Context, MBB, MBBI, PhysLiveRegs, + /*IsSave=*/false); + return emitRestoreLazySave(Context, MBB, MBBI, PhysLiveRegs); } - void emitAllocateZASaveBuffer(MachineBasicBlock &MBB, + void emitAllocateZASaveBuffer(EmitContext &Context, MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, 
LiveRegs PhysLiveRegs) { if (AFI->getSMEFnAttrs().hasAgnosticZAInterface()) - return emitAllocateFullZASaveBuffer(MBB, MBBI, PhysLiveRegs); - return emitAllocateLazySaveBuffer(MBB, MBBI); + return emitAllocateFullZASaveBuffer(Context, MBB, MBBI, PhysLiveRegs); + return emitAllocateLazySaveBuffer(Context, MBB, MBBI); } /// Save live physical registers to virtual registers. PhysRegSave createPhysRegSave(LiveRegs PhysLiveRegs, MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, DebugLoc DL); /// Restore physical registers from a save of their previous values. - void restorePhyRegSave(PhysRegSave const &RegSave, MachineBasicBlock &MBB, + void restorePhyRegSave(const PhysRegSave &RegSave, MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, DebugLoc DL); - /// Get or create a TPIDR2 block in this function. - TPIDR2State getTPIDR2Block(); - - Register getAgnosticZABufferPtr(); - private: - /// Contains the needed ZA state (and live registers) at an instruction. - struct InstInfo { - ZAState NeededState{ZAState::ANY}; - MachineBasicBlock::iterator InsertPt; - LiveRegs PhysLiveRegs = LiveRegs::None; - }; - - /// Contains the needed ZA state for each instruction in a block. - /// Instructions that do not require a ZA state are not recorded. - struct BlockInfo { - ZAState FixedEntryState{ZAState::ANY}; - SmallVector Insts; - LiveRegs PhysLiveRegsAtEntry = LiveRegs::None; - LiveRegs PhysLiveRegsAtExit = LiveRegs::None; - }; - - // All pass state that must be cleared between functions. 
- struct PassState { - SmallVector Blocks; - SmallVector BundleStates; - std::optional TPIDR2Block; - std::optional AfterSMEProloguePt; - Register AgnosticZABufferPtr = AArch64::NoRegister; - LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None; - } State; - MachineFunction *MF = nullptr; - EdgeBundles *Bundles = nullptr; const AArch64Subtarget *Subtarget = nullptr; const AArch64RegisterInfo *TRI = nullptr; const AArch64FunctionInfo *AFI = nullptr; @@ -301,14 +337,18 @@ struct MachineSMEABI : public MachineFunctionPass { MachineRegisterInfo *MRI = nullptr; }; -void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) { +FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) { assert((SMEFnAttrs.hasAgnosticZAInterface() || SMEFnAttrs.hasZT0State() || SMEFnAttrs.hasZAState()) && "Expected function to have ZA/ZT0 state!"); - State.Blocks.resize(MF->getNumBlockIDs()); + SmallVector Blocks(MF->getNumBlockIDs()); + LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None; + std::optional AfterSMEProloguePt; + for (MachineBasicBlock &MBB : *MF) { - BlockInfo &Block = State.Blocks[MBB.getNumber()]; + BlockInfo &Block = Blocks[MBB.getNumber()]; + if (MBB.isEntryBlock()) { // Entry block: Block.FixedEntryState = SMEFnAttrs.hasPrivateZAInterface() @@ -347,8 +387,8 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) { // allocation -- which is a safe point for this pass to insert any TPIDR2 // block setup. if (MI.getOpcode() == AArch64::SMEStateAllocPseudo) { - State.AfterSMEProloguePt = MBBI; - State.PhysLiveRegsAfterSMEPrologue = PhysLiveRegs; + AfterSMEProloguePt = MBBI; + PhysLiveRegsAfterSMEPrologue = PhysLiveRegs; } // Note: We treat Agnostic ZA as inout_za with an alternate save/restore. auto [NeededState, InsertPt] = getZAStateBeforeInst( @@ -368,11 +408,18 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) { // Reverse vector (as we had to iterate backwards for liveness). 
std::reverse(Block.Insts.begin(), Block.Insts.end()); } + + return FunctionInfo{std::move(Blocks), AfterSMEProloguePt, + PhysLiveRegsAfterSMEPrologue}; } -void MachineSMEABI::assignBundleZAStates() { - State.BundleStates.resize(Bundles->getNumBundles()); - for (unsigned I = 0, E = Bundles->getNumBundles(); I != E; ++I) { +/// Assigns each edge bundle a ZA state based on the needed states of blocks +/// that have incoming or outgoing edges in that bundle. +SmallVector +MachineSMEABI::assignBundleZAStates(const EdgeBundles &Bundles, + const FunctionInfo &FnInfo) { + SmallVector BundleStates(Bundles.getNumBundles()); + for (unsigned I = 0, E = Bundles.getNumBundles(); I != E; ++I) { LLVM_DEBUG(dbgs() << "Assigning ZA state for edge bundle: " << I << '\n'); // Attempt to assign a ZA state for this bundle that minimizes state @@ -381,16 +428,16 @@ void MachineSMEABI::assignBundleZAStates() { // TODO: We should propagate desired incoming/outgoing states through blocks // that have the "ANY" state first to make better global decisions. int EdgeStateCounts[ZAState::NUM_ZA_STATE] = {0}; - for (unsigned BlockID : Bundles->getBlocks(I)) { + for (unsigned BlockID : Bundles.getBlocks(I)) { LLVM_DEBUG(dbgs() << "- bb." 
<< BlockID); - const BlockInfo &Block = State.Blocks[BlockID]; + const BlockInfo &Block = FnInfo.Blocks[BlockID]; if (Block.Insts.empty()) { LLVM_DEBUG(dbgs() << " (no state preference)\n"); continue; } - bool InEdge = Bundles->getBundle(BlockID, /*Out=*/false) == I; - bool OutEdge = Bundles->getBundle(BlockID, /*Out=*/true) == I; + bool InEdge = Bundles.getBundle(BlockID, /*Out=*/false) == I; + bool OutEdge = Bundles.getBundle(BlockID, /*Out=*/true) == I; ZAState DesiredIncomingState = Block.Insts.front().NeededState; if (InEdge && isLegalEdgeBundleZAState(DesiredIncomingState)) { @@ -423,15 +470,20 @@ void MachineSMEABI::assignBundleZAStates() { dbgs() << "\n\n"; }); - State.BundleStates[I] = BundleState; + BundleStates[I] = BundleState; } + + return BundleStates; } -void MachineSMEABI::insertStateChanges() { +void MachineSMEABI::insertStateChanges(EmitContext &Context, + const FunctionInfo &FnInfo, + const EdgeBundles &Bundles, + ArrayRef BundleStates) { for (MachineBasicBlock &MBB : *MF) { - const BlockInfo &Block = State.Blocks[MBB.getNumber()]; - ZAState InState = State.BundleStates[Bundles->getBundle(MBB.getNumber(), - /*Out=*/false)]; + const BlockInfo &Block = FnInfo.Blocks[MBB.getNumber()]; + ZAState InState = BundleStates[Bundles.getBundle(MBB.getNumber(), + /*Out=*/false)]; ZAState CurrentState = Block.FixedEntryState; if (CurrentState == ZAState::ANY) @@ -439,8 +491,8 @@ void MachineSMEABI::insertStateChanges() { for (auto &Inst : Block.Insts) { if (CurrentState != Inst.NeededState) - emitStateChange(MBB, Inst.InsertPt, CurrentState, Inst.NeededState, - Inst.PhysLiveRegs); + emitStateChange(Context, MBB, Inst.InsertPt, CurrentState, + Inst.NeededState, Inst.PhysLiveRegs); CurrentState = Inst.NeededState; } @@ -448,21 +500,13 @@ void MachineSMEABI::insertStateChanges() { continue; ZAState OutState = - State.BundleStates[Bundles->getBundle(MBB.getNumber(), /*Out=*/true)]; + BundleStates[Bundles.getBundle(MBB.getNumber(), /*Out=*/true)]; if (CurrentState 
!= OutState) - emitStateChange(MBB, MBB.getFirstTerminator(), CurrentState, OutState, - Block.PhysLiveRegsAtExit); + emitStateChange(Context, MBB, MBB.getFirstTerminator(), CurrentState, + OutState, Block.PhysLiveRegsAtExit); } } -TPIDR2State MachineSMEABI::getTPIDR2Block() { - if (State.TPIDR2Block) - return *State.TPIDR2Block; - MachineFrameInfo &MFI = MF->getFrameInfo(); - State.TPIDR2Block = TPIDR2State{MFI.CreateStackObject(16, Align(16), false)}; - return *State.TPIDR2Block; -} - static DebugLoc getDebugLoc(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) { if (MBBI != MBB.end()) @@ -470,7 +514,8 @@ static DebugLoc getDebugLoc(MachineBasicBlock &MBB, return DebugLoc(); } -void MachineSMEABI::emitSetupLazySave(MachineBasicBlock &MBB, +void MachineSMEABI::emitSetupLazySave(EmitContext &Context, + MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) { DebugLoc DL = getDebugLoc(MBB, MBBI); @@ -478,7 +523,7 @@ void MachineSMEABI::emitSetupLazySave(MachineBasicBlock &MBB, Register TPIDR2 = MRI->createVirtualRegister(&AArch64::GPR64spRegClass); Register TPIDR2Ptr = MRI->createVirtualRegister(&AArch64::GPR64RegClass); BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), TPIDR2) - .addFrameIndex(getTPIDR2Block().FrameIndex) + .addFrameIndex(Context.getTPIDR2Block(*MF)) .addImm(0) .addImm(0); BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), TPIDR2Ptr) @@ -512,7 +557,7 @@ PhysRegSave MachineSMEABI::createPhysRegSave(LiveRegs PhysLiveRegs, return RegSave; } -void MachineSMEABI::restorePhyRegSave(PhysRegSave const &RegSave, +void MachineSMEABI::restorePhyRegSave(const PhysRegSave &RegSave, MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, DebugLoc DL) { @@ -528,7 +573,8 @@ void MachineSMEABI::restorePhyRegSave(PhysRegSave const &RegSave, .addReg(RegSave.X0Save); } -void MachineSMEABI::emitRestoreLazySave(MachineBasicBlock &MBB, +void MachineSMEABI::emitRestoreLazySave(EmitContext &Context, + MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, 
LiveRegs PhysLiveRegs) { auto *TLI = Subtarget->getTargetLowering(); @@ -548,7 +594,7 @@ void MachineSMEABI::emitRestoreLazySave(MachineBasicBlock &MBB, .addImm(AArch64SysReg::TPIDR2_EL0); // Get pointer to TPIDR2 block. BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), TPIDR2) - .addFrameIndex(getTPIDR2Block().FrameIndex) + .addFrameIndex(Context.getTPIDR2Block(*MF)) .addImm(0) .addImm(0); // (Conditionally) restore ZA state. @@ -582,7 +628,8 @@ void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB, } void MachineSMEABI::emitAllocateLazySaveBuffer( - MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) { + EmitContext &Context, MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI) { MachineFrameInfo &MFI = MF->getFrameInfo(); DebugLoc DL = getDebugLoc(MBB, MBBI); Register SP = MRI->createVirtualRegister(&AArch64::GPR64RegClass); @@ -630,7 +677,7 @@ void MachineSMEABI::emitAllocateLazySaveBuffer( BuildMI(MBB, MBBI, DL, TII->get(AArch64::STPXi)) .addReg(Buffer) .addReg(SVL) - .addFrameIndex(getTPIDR2Block().FrameIndex) + .addFrameIndex(Context.getTPIDR2Block(*MF)) .addImm(0); } } @@ -662,18 +709,8 @@ void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB, .addImm(1); } -Register MachineSMEABI::getAgnosticZABufferPtr() { - if (State.AgnosticZABufferPtr != AArch64::NoRegister) - return State.AgnosticZABufferPtr; - Register BufferPtr = AFI->getEarlyAllocSMESaveBuffer(); - State.AgnosticZABufferPtr = - BufferPtr != AArch64::NoRegister - ? BufferPtr - : MF->getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass); - return State.AgnosticZABufferPtr; -} - -void MachineSMEABI::emitFullZASaveRestore(MachineBasicBlock &MBB, +void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context, + MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs, bool IsSave) { auto *TLI = Subtarget->getTargetLowering(); @@ -684,7 +721,7 @@ void MachineSMEABI::emitFullZASaveRestore(MachineBasicBlock &MBB, // Copy the buffer pointer into X0. 
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferPtr) - .addReg(getAgnosticZABufferPtr()); + .addReg(Context.getAgnosticZABufferPtr(*MF)); // Call __arm_sme_save/__arm_sme_restore. BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL)) @@ -699,14 +736,14 @@ void MachineSMEABI::emitFullZASaveRestore(MachineBasicBlock &MBB, } void MachineSMEABI::emitAllocateFullZASaveBuffer( - MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, - LiveRegs PhysLiveRegs) { + EmitContext &Context, MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) { // Buffer already allocated in SelectionDAG. if (AFI->getEarlyAllocSMESaveBuffer()) return; DebugLoc DL = getDebugLoc(MBB, MBBI); - Register BufferPtr = getAgnosticZABufferPtr(); + Register BufferPtr = Context.getAgnosticZABufferPtr(*MF); Register BufferSize = MRI->createVirtualRegister(&AArch64::GPR64RegClass); PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL); @@ -742,11 +779,11 @@ void MachineSMEABI::emitAllocateFullZASaveBuffer( restorePhyRegSave(RegSave, MBB, MBBI, DL); } -void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB, +void MachineSMEABI::emitStateChange(EmitContext &Context, + MachineBasicBlock &MBB, MachineBasicBlock::iterator InsertPt, ZAState From, ZAState To, LiveRegs PhysLiveRegs) { - // ZA not used. 
if (From == ZAState::ANY || To == ZAState::ANY) return; @@ -774,9 +811,9 @@ void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB, } if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED) - emitZASave(MBB, InsertPt, PhysLiveRegs); + emitZASave(Context, MBB, InsertPt, PhysLiveRegs); else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE) - emitZARestore(MBB, InsertPt, PhysLiveRegs); + emitZARestore(Context, MBB, InsertPt, PhysLiveRegs); else if (To == ZAState::OFF) { assert(From != ZAState::CALLER_DORMANT && "CALLER_DORMANT to OFF should have already been handled"); @@ -807,32 +844,33 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) { assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!"); - // Reset pass state. - State = PassState{}; this->MF = &MF; - Bundles = &getAnalysis().getEdgeBundles(); Subtarget = &MF.getSubtarget(); TII = Subtarget->getInstrInfo(); TRI = Subtarget->getRegisterInfo(); MRI = &MF.getRegInfo(); - collectNeededZAStates(SMEFnAttrs); - assignBundleZAStates(); - insertStateChanges(); + const EdgeBundles &Bundles = + getAnalysis().getEdgeBundles(); + + FunctionInfo FnInfo = collectNeededZAStates(SMEFnAttrs); + SmallVector BundleStates = assignBundleZAStates(Bundles, FnInfo); + + EmitContext Context; + insertStateChanges(Context, FnInfo, Bundles, BundleStates); - // Allocate save buffer (if needed). - if (State.AgnosticZABufferPtr != AArch64::NoRegister || State.TPIDR2Block) { - if (State.AfterSMEProloguePt) { + if (Context.needsSaveBuffer()) { + if (FnInfo.AfterSMEProloguePt) { // Note: With inline stack probes the AfterSMEProloguePt may not be in the // entry block (due to the probing loop). 
- emitAllocateZASaveBuffer(*(*State.AfterSMEProloguePt)->getParent(), - *State.AfterSMEProloguePt, - State.PhysLiveRegsAfterSMEPrologue); + MachineBasicBlock::iterator MBBI = *FnInfo.AfterSMEProloguePt; + emitAllocateZASaveBuffer(Context, *MBBI->getParent(), MBBI, + FnInfo.PhysLiveRegsAfterSMEPrologue); } else { MachineBasicBlock &EntryBlock = MF.front(); emitAllocateZASaveBuffer( - EntryBlock, EntryBlock.getFirstNonPHI(), - State.Blocks[EntryBlock.getNumber()].PhysLiveRegsAtEntry); + Context, EntryBlock, EntryBlock.getFirstNonPHI(), + FnInfo.Blocks[EntryBlock.getNumber()].PhysLiveRegsAtEntry); } }