diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 60aa61e993b26..30f961043e78b 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -8735,15 +8735,6 @@ SDValue AArch64TargetLowering::LowerFormalArguments( } } - if (getTM().useNewSMEABILowering()) { - // Clear new ZT0 state. TODO: Move this to the SME ABI pass. - if (Attrs.isNewZT0()) - Chain = DAG.getNode( - ISD::INTRINSIC_VOID, DL, MVT::Other, Chain, - DAG.getConstant(Intrinsic::aarch64_sme_zero_zt, DL, MVT::i32), - DAG.getTargetConstant(0, DL, MVT::i32)); - } - return Chain; } diff --git a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp index 8f9aae944ad6d..bb4dfe8c60904 100644 --- a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp +++ b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp @@ -82,8 +82,8 @@ enum ZAState { // A ZA save has been set up or committed (i.e. ZA is dormant or off) LOCAL_SAVED, - // ZA is off or a lazy save has been set up by the caller - CALLER_DORMANT, + // The ZA/ZT0 state on entry to the function. + ENTRY, // ZA is off OFF, @@ -200,7 +200,7 @@ StringRef getZAStateString(ZAState State) { MAKE_CASE(ZAState::ANY) MAKE_CASE(ZAState::ACTIVE) MAKE_CASE(ZAState::LOCAL_SAVED) - MAKE_CASE(ZAState::CALLER_DORMANT) + MAKE_CASE(ZAState::ENTRY) MAKE_CASE(ZAState::OFF) default: llvm_unreachable("Unexpected ZAState"); @@ -281,8 +281,8 @@ struct MachineSMEABI : public MachineFunctionPass { void propagateDesiredStates(FunctionInfo &FnInfo, bool Forwards = true); // Emission routines for private and shared ZA functions (using lazy saves). - void emitNewZAPrologue(MachineBasicBlock &MBB, - MachineBasicBlock::iterator MBBI); + void emitSMEPrologue(MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI); void emitRestoreLazySave(EmitContext &, MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs); @@ -395,9 +395,7 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) { if (MBB.isEntryBlock()) { // Entry block: - Block.FixedEntryState = SMEFnAttrs.hasPrivateZAInterface() - ? ZAState::CALLER_DORMANT - : ZAState::ACTIVE; + Block.FixedEntryState = ZAState::ENTRY; } else if (MBB.isEHPad()) { // EH entry block: Block.FixedEntryState = ZAState::LOCAL_SAVED; @@ -815,32 +813,49 @@ void MachineSMEABI::emitAllocateLazySaveBuffer( } } -void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB, - MachineBasicBlock::iterator MBBI) { +static constexpr unsigned ZERO_ALL_ZA_MASK = 0b11111111; + +void MachineSMEABI::emitSMEPrologue(MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI) { auto *TLI = Subtarget->getTargetLowering(); DebugLoc DL = getDebugLoc(MBB, MBBI); - // Get current TPIDR2_EL0. - Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass); - BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS)) - .addReg(TPIDR2EL0, RegState::Define) - .addImm(AArch64SysReg::TPIDR2_EL0); - // If TPIDR2_EL0 is non-zero, commit the lazy save. - // NOTE: Functions that only use ZT0 don't need to zero ZA. - bool ZeroZA = AFI->getSMEFnAttrs().hasZAState(); - auto CommitZASave = - BuildMI(MBB, MBBI, DL, TII->get(AArch64::CommitZASavePseudo)) - .addReg(TPIDR2EL0) - .addImm(ZeroZA ? 1 : 0) - .addImm(/*ZeroZT0=*/false) - .addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_TPIDR2_SAVE)) - .addRegMask(TRI->SMEABISupportRoutinesCallPreservedMaskFromX0()); - if (ZeroZA) - CommitZASave.addDef(AArch64::ZAB0, RegState::ImplicitDefine); - // Enable ZA (as ZA could have previously been in the OFF state). - BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1)) - .addImm(AArch64SVCR::SVCRZA) - .addImm(1); + bool ZeroZA = AFI->getSMEFnAttrs().isNewZA(); + bool ZeroZT0 = AFI->getSMEFnAttrs().isNewZT0(); + if (AFI->getSMEFnAttrs().hasPrivateZAInterface()) { + // Get current TPIDR2_EL0. + Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass); + BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS)) + .addReg(TPIDR2EL0, RegState::Define) + .addImm(AArch64SysReg::TPIDR2_EL0); + // If TPIDR2_EL0 is non-zero, commit the lazy save. + // NOTE: Functions that only use ZT0 don't need to zero ZA. + auto CommitZASave = + BuildMI(MBB, MBBI, DL, TII->get(AArch64::CommitZASavePseudo)) + .addReg(TPIDR2EL0) + .addImm(ZeroZA) + .addImm(ZeroZT0) + .addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_TPIDR2_SAVE)) + .addRegMask(TRI->SMEABISupportRoutinesCallPreservedMaskFromX0()); + if (ZeroZA) + CommitZASave.addDef(AArch64::ZAB0, RegState::ImplicitDefine); + if (ZeroZT0) + CommitZASave.addDef(AArch64::ZT0, RegState::ImplicitDefine); + // Enable ZA (as ZA could have previously been in the OFF state). + BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1)) + .addImm(AArch64SVCR::SVCRZA) + .addImm(1); + } else if (AFI->getSMEFnAttrs().hasSharedZAInterface()) { + if (ZeroZA) { + BuildMI(MBB, MBBI, DL, TII->get(AArch64::ZERO_M)) + .addImm(ZERO_ALL_ZA_MASK) + .addDef(AArch64::ZAB0, RegState::ImplicitDefine); + } + if (ZeroZT0) { + DebugLoc DL = getDebugLoc(MBB, MBBI); + BuildMI(MBB, MBBI, DL, TII->get(AArch64::ZERO_T)).addDef(AArch64::ZT0); + } + } } void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context, @@ -922,19 +937,19 @@ void MachineSMEABI::emitStateChange(EmitContext &Context, if (From == ZAState::ANY || To == ZAState::ANY) return; - // If we're exiting from the CALLER_DORMANT state that means this new ZA - // function did not touch ZA (so ZA was never turned on). - if (From == ZAState::CALLER_DORMANT && To == ZAState::OFF) + // If we're exiting from the ENTRY state that means that the function has not + // used ZA, so in the case of private ZA/ZT0 functions we can omit any set up. + if (From == ZAState::ENTRY && To == ZAState::OFF) return; + SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs(); + // TODO: Avoid setting up the save buffer if there's no transition to // LOCAL_SAVED. - if (From == ZAState::CALLER_DORMANT) { - assert(AFI->getSMEFnAttrs().hasPrivateZAInterface() && - "CALLER_DORMANT state requires private ZA interface"); + if (From == ZAState::ENTRY) { assert(&MBB == &MBB.getParent()->front() && - "CALLER_DORMANT state only valid in entry block"); - emitNewZAPrologue(MBB, MBB.getFirstNonPHI()); + "ENTRY state only valid in entry block"); + emitSMEPrologue(MBB, MBB.getFirstNonPHI()); if (To == ZAState::ACTIVE) return; // Nothing more to do (ZA is active after the prologue). @@ -949,9 +964,9 @@ void MachineSMEABI::emitStateChange(EmitContext &Context, else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE) emitZARestore(Context, MBB, InsertPt, PhysLiveRegs); else if (To == ZAState::OFF) { - assert(From != ZAState::CALLER_DORMANT && - "CALLER_DORMANT to OFF should have already been handled"); - assert(!AFI->getSMEFnAttrs().hasAgnosticZAInterface() && + assert(From != ZAState::ENTRY && + "ENTRY to OFF should have already been handled"); + assert(!SMEFnAttrs.hasAgnosticZAInterface() && "Should not turn ZA off in agnostic ZA function"); emitZAOff(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED); } else { diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll index 5b81f5dafe421..4c48e41294a3a 100644 --- a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll +++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll @@ -199,9 +199,9 @@ define void @zt0_new_caller_zt0_new_callee(ptr %callee) "aarch64_new_zt0" nounwi ; CHECK-NEWLOWERING-NEXT: // %bb.1: ; CHECK-NEWLOWERING-NEXT: bl __arm_tpidr2_save ; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, xzr +; CHECK-NEWLOWERING-NEXT: zero { zt0 } ; CHECK-NEWLOWERING-NEXT: .LBB6_2: ; CHECK-NEWLOWERING-NEXT: smstart za -; CHECK-NEWLOWERING-NEXT: zero { zt0 } ; CHECK-NEWLOWERING-NEXT: mov x19, sp ; CHECK-NEWLOWERING-NEXT: str zt0, [x19] ; CHECK-NEWLOWERING-NEXT: smstop za @@ -252,9 +252,9 @@ define i64 @zt0_new_caller_abi_routine_callee() "aarch64_new_zt0" nounwind { ; CHECK-NEWLOWERING-NEXT: // %bb.1: ; CHECK-NEWLOWERING-NEXT: bl __arm_tpidr2_save ; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, xzr +; CHECK-NEWLOWERING-NEXT: zero { zt0 } ; CHECK-NEWLOWERING-NEXT: .LBB7_2: ; CHECK-NEWLOWERING-NEXT: smstart za -; CHECK-NEWLOWERING-NEXT: zero { zt0 } ; CHECK-NEWLOWERING-NEXT: mov x19, sp ; CHECK-NEWLOWERING-NEXT: str zt0, [x19] ; CHECK-NEWLOWERING-NEXT: bl __arm_sme_state @@ -302,9 +302,9 @@ define void @zt0_new_caller(ptr %callee) "aarch64_new_zt0" nounwind { ; CHECK-NEWLOWERING-NEXT: // %bb.1: ; CHECK-NEWLOWERING-NEXT: bl __arm_tpidr2_save ; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, xzr +; CHECK-NEWLOWERING-NEXT: zero { zt0 } ; CHECK-NEWLOWERING-NEXT: .LBB8_2: ; CHECK-NEWLOWERING-NEXT: smstart za -; CHECK-NEWLOWERING-NEXT: zero { zt0 } ; CHECK-NEWLOWERING-NEXT: blr x0 ; CHECK-NEWLOWERING-NEXT: smstop za ; CHECK-NEWLOWERING-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload @@ -343,9 +343,9 @@ define void @new_za_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_new_zt0" n ; CHECK-NEWLOWERING-NEXT: bl __arm_tpidr2_save ; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, xzr ; CHECK-NEWLOWERING-NEXT: zero {za} +; CHECK-NEWLOWERING-NEXT: zero { zt0 } ; CHECK-NEWLOWERING-NEXT: .LBB9_2: ; CHECK-NEWLOWERING-NEXT: smstart za -; CHECK-NEWLOWERING-NEXT: zero { zt0 } ; CHECK-NEWLOWERING-NEXT: blr x0 ; CHECK-NEWLOWERING-NEXT: smstop za ; CHECK-NEWLOWERING-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload @@ -356,20 +356,13 @@ define void @new_za_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_new_zt0" n ; Expect clear ZA on entry define void @new_za_shared_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_in_zt0" nounwind { -; CHECK-LABEL: new_za_shared_zt0_caller: -; CHECK: // %bb.0: -; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill -; CHECK-NEXT: zero {za} -; CHECK-NEXT: blr x0 -; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload -; CHECK-NEXT: ret -; -; CHECK-NEWLOWERING-LABEL: new_za_shared_zt0_caller: -; CHECK-NEWLOWERING: // %bb.0: -; CHECK-NEWLOWERING-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill -; CHECK-NEWLOWERING-NEXT: blr x0 -; CHECK-NEWLOWERING-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload -; CHECK-NEWLOWERING-NEXT: ret +; CHECK-COMMON-LABEL: new_za_shared_zt0_caller: +; CHECK-COMMON: // %bb.0: +; CHECK-COMMON-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill +; CHECK-COMMON-NEXT: zero {za} +; CHECK-COMMON-NEXT: blr x0 +; CHECK-COMMON-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload +; CHECK-COMMON-NEXT: ret call void %callee() "aarch64_inout_za" "aarch64_in_zt0"; ret void; }