Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1717,6 +1717,7 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
}
case AArch64::InOutZAUsePseudo:
case AArch64::RequiresZASavePseudo:
case AArch64::RequiresZT0SavePseudo:
case AArch64::SMEStateAllocPseudo:
case AArch64::COALESCER_BARRIER_FPR16:
case AArch64::COALESCER_BARRIER_FPR32:
Expand Down
11 changes: 8 additions & 3 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9457,6 +9457,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
if (CallAttrs.requiresLazySave() ||
CallAttrs.requiresPreservingAllZAState())
ZAMarkerNode = AArch64ISD::REQUIRES_ZA_SAVE;
else if (CallAttrs.requiresPreservingZT0())
ZAMarkerNode = AArch64ISD::REQUIRES_ZT0_SAVE;
else if (CallAttrs.caller().hasZAState() ||
CallAttrs.caller().hasZT0State())
ZAMarkerNode = AArch64ISD::INOUT_ZA_USE;
Expand Down Expand Up @@ -9576,7 +9578,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,

SDValue ZTFrameIdx;
MachineFrameInfo &MFI = MF.getFrameInfo();
bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0();
bool ShouldPreserveZT0 =
!UseNewSMEABILowering && CallAttrs.requiresPreservingZT0();

// If the caller has ZT0 state which will not be preserved by the callee,
// spill ZT0 before the call.
Expand All @@ -9589,7 +9592,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,

// If caller shares ZT0 but the callee is not shared ZA, we need to stop
// PSTATE.ZA before the call if there is no lazy-save active.
bool DisableZA = CallAttrs.requiresDisablingZABeforeCall();
bool DisableZA =
!UseNewSMEABILowering && CallAttrs.requiresDisablingZABeforeCall();
assert((!DisableZA || !RequiresLazySave) &&
"Lazy-save should have PSTATE.SM=1 on entry to the function");

Expand Down Expand Up @@ -10074,7 +10078,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
getSMToggleCondition(CallAttrs));
}

if (RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall())
if (!UseNewSMEABILowering &&
(RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall()))
// Unconditionally resume ZA.
Result = DAG.getNode(
AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result,
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def : Pat<(i64 (AArch64AllocateSMESaveBuffer GPR64:$size)),
let hasSideEffects = 1, isMeta = 1 in {
def InOutZAUsePseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
def RequiresZASavePseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
def RequiresZT0SavePseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
}

def SMEStateAllocPseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
Expand All @@ -122,6 +123,11 @@ def AArch64_requires_za_save
[SDNPHasChain, SDNPInGlue]>;
def : Pat<(AArch64_requires_za_save), (RequiresZASavePseudo)>;

def AArch64_requires_zt0_save
: SDNode<"AArch64ISD::REQUIRES_ZT0_SAVE", SDTypeProfile<0, 0, []>,
[SDNPHasChain, SDNPInGlue]>;
def : Pat<(AArch64_requires_zt0_save), (RequiresZT0SavePseudo)>;

def AArch64_sme_state_alloc
: SDNode<"AArch64ISD::SME_STATE_ALLOC", SDTypeProfile<0, 0,[]>,
[SDNPHasChain]>;
Expand Down
176 changes: 150 additions & 26 deletions llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,30 @@ using namespace llvm;

namespace {

enum ZAState {
// Note: For agnostic ZA, we assume the function is always entered/exited in the
// "ACTIVE" state -- this _may_ not be the case (since OFF is also a
// possibility, but for the purpose of placing ZA saves/restores, that does not
// matter).
enum ZAState : uint8_t {
// Any/unknown state (not valid)
ANY = 0,

// ZA is in use and active (i.e. within the accumulator)
ACTIVE,

// ZA is active, but ZT0 has been saved.
// This handles the edge case of sharedZA && !sharesZT0.
ACTIVE_ZT0_SAVED,

// A ZA save has been set up or committed (i.e. ZA is dormant or off)
// If the function uses ZT0 it must also be saved.
LOCAL_SAVED,

// ZA has been committed to the lazy save buffer of the current function.
// If the function uses ZT0 it must also be saved.
// ZA is off when a save has been committed.
LOCAL_COMMITTED,

// The ZA/ZT0 state on entry to the function.
ENTRY,

Expand Down Expand Up @@ -164,6 +178,14 @@ class EmitContext {
return AgnosticZABufferPtr;
}

int getZT0SaveSlot(MachineFunction &MF) {
if (ZT0SaveFI)
return *ZT0SaveFI;
MachineFrameInfo &MFI = MF.getFrameInfo();
ZT0SaveFI = MFI.CreateSpillStackObject(64, Align(16));
return *ZT0SaveFI;
}

/// Returns true if the function must allocate a ZA save buffer on entry. This
/// will be the case if, at any point in the function, a ZA save was emitted.
bool needsSaveBuffer() const {
Expand All @@ -173,6 +195,7 @@ class EmitContext {
}

private:
std::optional<int> ZT0SaveFI;
std::optional<int> TPIDR2BlockFI;
Register AgnosticZABufferPtr = AArch64::NoRegister;
};
Expand All @@ -184,8 +207,10 @@ class EmitContext {
/// state would not be legal, as transitioning to it drops the content of ZA.
static bool isLegalEdgeBundleZAState(ZAState State) {
switch (State) {
case ZAState::ACTIVE: // ZA state within the accumulator/ZT0.
case ZAState::LOCAL_SAVED: // ZA state is saved on the stack.
case ZAState::ACTIVE: // ZA state within the accumulator/ZT0.
case ZAState::ACTIVE_ZT0_SAVED: // ZT0 is saved (ZA is active).
case ZAState::LOCAL_SAVED: // ZA state may be saved on the stack.
case ZAState::LOCAL_COMMITTED: // ZA state is saved on the stack.
return true;
default:
return false;
Expand All @@ -199,7 +224,9 @@ StringRef getZAStateString(ZAState State) {
switch (State) {
MAKE_CASE(ZAState::ANY)
MAKE_CASE(ZAState::ACTIVE)
MAKE_CASE(ZAState::ACTIVE_ZT0_SAVED)
MAKE_CASE(ZAState::LOCAL_SAVED)
MAKE_CASE(ZAState::LOCAL_COMMITTED)
MAKE_CASE(ZAState::ENTRY)
MAKE_CASE(ZAState::OFF)
default:
Expand All @@ -221,18 +248,34 @@ static bool isZAorZTRegOp(const TargetRegisterInfo &TRI,
/// Returns the required ZA state needed before \p MI and an iterator pointing
/// to where any code required to change the ZA state should be inserted.
static std::pair<ZAState, MachineBasicBlock::iterator>
getZAStateBeforeInst(const TargetRegisterInfo &TRI, MachineInstr &MI,
bool ZAOffAtReturn) {
getInstNeededZAState(const TargetRegisterInfo &TRI, MachineInstr &MI,
SMEAttrs SMEFnAttrs) {
MachineBasicBlock::iterator InsertPt(MI);

if (MI.getOpcode() == AArch64::InOutZAUsePseudo)
return {ZAState::ACTIVE, std::prev(InsertPt)};

// Note: If we need to save both ZA and ZT0 we use RequiresZASavePseudo.
if (MI.getOpcode() == AArch64::RequiresZASavePseudo)
return {ZAState::LOCAL_SAVED, std::prev(InsertPt)};

if (MI.isReturn())
// If we only need to save ZT0 there's two cases to consider:
// 1. The function has ZA state (that we don't need to save).
// - In this case we switch to the "ACTIVE_ZT0_SAVED" state.
// This only saves ZT0.
// 2. The function does not have ZA state
// - In this case we switch to "LOCAL_COMMITTED" state.
// This saves ZT0 and turns ZA off.
if (MI.getOpcode() == AArch64::RequiresZT0SavePseudo) {
return {SMEFnAttrs.hasZAState() ? ZAState::ACTIVE_ZT0_SAVED
: ZAState::LOCAL_COMMITTED,
std::prev(InsertPt)};
}

if (MI.isReturn()) {
bool ZAOffAtReturn = SMEFnAttrs.hasPrivateZAInterface();
return {ZAOffAtReturn ? ZAState::OFF : ZAState::ACTIVE, InsertPt};
}

for (auto &MO : MI.operands()) {
if (isZAorZTRegOp(TRI, MO))
Expand Down Expand Up @@ -280,6 +323,9 @@ struct MachineSMEABI : public MachineFunctionPass {
/// predecessors).
void propagateDesiredStates(FunctionInfo &FnInfo, bool Forwards = true);

void emitZT0SaveRestore(EmitContext &, MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI, bool IsSave);

// Emission routines for private and shared ZA functions (using lazy saves).
void emitSMEPrologue(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI);
Expand All @@ -290,8 +336,8 @@ struct MachineSMEABI : public MachineFunctionPass {
MachineBasicBlock::iterator MBBI);
void emitAllocateLazySaveBuffer(EmitContext &, MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI);
void emitZAOff(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
bool ClearTPIDR2);
void emitZAMode(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
bool ClearTPIDR2, bool On);

// Emission routines for agnostic ZA functions.
void emitSetupFullZASave(MachineBasicBlock &MBB,
Expand Down Expand Up @@ -398,7 +444,7 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
Block.FixedEntryState = ZAState::ENTRY;
} else if (MBB.isEHPad()) {
// EH entry block:
Block.FixedEntryState = ZAState::LOCAL_SAVED;
Block.FixedEntryState = ZAState::LOCAL_COMMITTED;
}

LiveRegUnits LiveUnits(*TRI);
Expand All @@ -420,8 +466,7 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
}
// Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
auto [NeededState, InsertPt] = getZAStateBeforeInst(
*TRI, MI, /*ZAOffAtReturn=*/SMEFnAttrs.hasPrivateZAInterface());
auto [NeededState, InsertPt] = getInstNeededZAState(*TRI, MI, SMEFnAttrs);
assert((InsertPt == MBBI ||
InsertPt->getOpcode() == AArch64::ADJCALLSTACKDOWN) &&
"Unexpected state change insertion point!");
Expand Down Expand Up @@ -742,9 +787,9 @@ void MachineSMEABI::emitRestoreLazySave(EmitContext &Context,
restorePhyRegSave(RegSave, MBB, MBBI, DL);
}

void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
bool ClearTPIDR2) {
void MachineSMEABI::emitZAMode(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
bool ClearTPIDR2, bool On) {
DebugLoc DL = getDebugLoc(MBB, MBBI);

if (ClearTPIDR2)
Expand All @@ -755,7 +800,7 @@ void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB,
// Disable ZA.
BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
.addImm(AArch64SVCR::SVCRZA)
.addImm(0);
.addImm(On ? 1 : 0);
}

void MachineSMEABI::emitAllocateLazySaveBuffer(
Expand Down Expand Up @@ -884,6 +929,28 @@ void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context,
restorePhyRegSave(RegSave, MBB, MBBI, DL);
}

void MachineSMEABI::emitZT0SaveRestore(EmitContext &Context,
MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
bool IsSave) {
DebugLoc DL = getDebugLoc(MBB, MBBI);
Register ZT0Save = MRI->createVirtualRegister(&AArch64::GPR64spRegClass);

BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), ZT0Save)
.addFrameIndex(Context.getZT0SaveSlot(*MF))
.addImm(0)
.addImm(0);

if (IsSave) {
BuildMI(MBB, MBBI, DL, TII->get(AArch64::STR_TX))
.addReg(AArch64::ZT0)
.addReg(ZT0Save);
} else {
BuildMI(MBB, MBBI, DL, TII->get(AArch64::LDR_TX), AArch64::ZT0)
.addReg(ZT0Save);
}
}

void MachineSMEABI::emitAllocateFullZASaveBuffer(
EmitContext &Context, MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) {
Expand Down Expand Up @@ -928,6 +995,17 @@ void MachineSMEABI::emitAllocateFullZASaveBuffer(
restorePhyRegSave(RegSave, MBB, MBBI, DL);
}

struct FromState {
ZAState From;

constexpr uint8_t to(ZAState To) const {
static_assert(NUM_ZA_STATE < 16, "expected ZAState to fit in 4-bits");
return uint8_t(From) << 4 | uint8_t(To);
}
};

constexpr FromState transitionFrom(ZAState From) { return FromState{From}; }

void MachineSMEABI::emitStateChange(EmitContext &Context,
MachineBasicBlock &MBB,
MachineBasicBlock::iterator InsertPt,
Expand Down Expand Up @@ -959,17 +1037,63 @@ void MachineSMEABI::emitStateChange(EmitContext &Context,
From = ZAState::ACTIVE;
}

if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
emitZASave(Context, MBB, InsertPt, PhysLiveRegs);
else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
emitZARestore(Context, MBB, InsertPt, PhysLiveRegs);
else if (To == ZAState::OFF) {
assert(From != ZAState::ENTRY &&
"ENTRY to OFF should have already been handled");
assert(!SMEFnAttrs.hasAgnosticZAInterface() &&
"Should not turn ZA off in agnostic ZA function");
emitZAOff(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED);
} else {
bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface();
bool HasZT0State = SMEFnAttrs.hasZT0State();
bool HasZAState = IsAgnosticZA || SMEFnAttrs.hasZAState();

switch (transitionFrom(From).to(To)) {
// This section handles: ACTIVE <-> ACTIVE_ZT0_SAVED
case transitionFrom(ZAState::ACTIVE).to(ZAState::ACTIVE_ZT0_SAVED):
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
break;
case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::ACTIVE):
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/false);
break;

// This section handles: ACTIVE -> LOCAL_SAVED
case transitionFrom(ZAState::ACTIVE).to(ZAState::LOCAL_SAVED):
if (HasZT0State)
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
if (HasZAState)
emitZASave(Context, MBB, InsertPt, PhysLiveRegs);
break;

// This section handles: ACTIVE -> LOCAL_COMMITTED
case transitionFrom(ZAState::ACTIVE).to(ZAState::LOCAL_COMMITTED):
// Note: We could support ZA state here, but this transition is currently
// only possible when we _don't_ have ZA state.
assert(HasZT0State && !HasZAState && "Expect to only have ZT0 state.");
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/false, /*On=*/false);
break;

// This section handles: LOCAL_COMMITTED -> (OFF|LOCAL_SAVED)
case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::OFF):
case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::LOCAL_SAVED):
// These transistions are a no-op.
break;

// This section handles: LOCAL_(SAVED|COMMITTED) -> ACTIVE[_ZT0_SAVED]
case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::ACTIVE):
case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::ACTIVE_ZT0_SAVED):
case transitionFrom(ZAState::LOCAL_SAVED).to(ZAState::ACTIVE):
if (HasZAState)
emitZARestore(Context, MBB, InsertPt, PhysLiveRegs);
else
emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/false, /*On=*/true);
if (HasZT0State && To == ZAState::ACTIVE)
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/false);
break;
default:
if (To == ZAState::OFF) {
assert(From != ZAState::ENTRY &&
"ENTRY to OFF should have already been handled");
assert(SMEFnAttrs.hasPrivateZAInterface() &&
"Did not expect to turn ZA off in shared/agnostic ZA function");
emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED,
/*On=*/false);
break;
}
dbgs() << "Error: Transition from " << getZAStateString(From) << " to "
<< getZAStateString(To) << '\n';
llvm_unreachable("Unimplemented state transition");
Expand Down
4 changes: 0 additions & 4 deletions llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,6 @@ define void @test7() nounwind "aarch64_inout_zt0" {
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: smstop za
; CHECK-NEXT: bl callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: ldr zt0, [x19]
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: smstop za
; CHECK-NEXT: bl callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: ldr zt0, [x19]
Expand Down
Loading