
[AArch64][SME2] Preserve ZT0 state around function calls #76968

Closed
wants to merge 2 commits

Conversation

kmclaughlin-arm (Contributor)

If a function has ZT0 state and calls a function which does not
preserve ZT0, the caller must save and restore ZT0 around the callee.

This patch extends SMEAttrs to interpret the following new attributes,
which apply to SME2 only:

  • aarch64_sme_pstate_zt0_new (ZT_New)
  • aarch64_sme_pstate_zt0_shared (ZT_Shared)
  • aarch64_sme_pstate_zt0_preserved (ZT_Preserved)

ZT0 must also be cleared on entry to a function marked with __arm_new_za.

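For example, with these attributes a caller that creates ZT0 state and calls a function with no ZA/ZT0 attributes must spill and reload ZT0 itself. A minimal IR sketch (the callee name is a placeholder; this mirrors the cases in the new sme-zt0-preserve.ll test):

; @private_callee carries no ZA/ZT0 attributes, so it is not assumed to
; preserve ZT0. With this patch, the caller emits "str zt0" before the call
; and "ldr zt0" after it.
declare void @private_callee()

define void @zt0_new_caller() "aarch64_pstate_za_new" "aarch64_sme_pstate_zt0_new" nounwind {
  call void @private_callee()
  ret void
}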
llvmbot commented Jan 4, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Kerry McLaughlin (kmclaughlin-arm)



Patch is 26.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76968.diff

9 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64FastISel.cpp (+2-1)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+30-1)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+2)
  • (modified) llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td (+8-2)
  • (modified) llvm/lib/Target/AArch64/SMEABIPass.cpp (+14-4)
  • (modified) llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp (+15-2)
  • (modified) llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h (+19-1)
  • (added) llvm/test/CodeGen/AArch64/sme-zt0-preserve.ll (+306)
  • (modified) llvm/unittests/Target/AArch64/SMEAttributesTest.cpp (+33)
diff --git a/llvm/lib/Target/AArch64/AArch64FastISel.cpp b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
index e98f6c4984a752..f63cdf8bc4f328 100644
--- a/llvm/lib/Target/AArch64/AArch64FastISel.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
@@ -5176,7 +5176,8 @@ FastISel *AArch64::createFastISel(FunctionLoweringInfo &FuncInfo,
                                         const TargetLibraryInfo *LibInfo) {
 
   SMEAttrs CallerAttrs(*FuncInfo.Fn);
-  if (CallerAttrs.hasZAState() || CallerAttrs.hasStreamingInterfaceOrBody() ||
+  if (CallerAttrs.hasZAState() || CallerAttrs.hasZTState() ||
+      CallerAttrs.hasStreamingInterfaceOrBody() ||
       CallerAttrs.hasStreamingCompatibleInterface())
     return nullptr;
   return new AArch64FastISel(FuncInfo, LibInfo);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 102fd0c3dae2ab..4121621616b8bd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2338,6 +2338,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::SMSTART)
     MAKE_CASE(AArch64ISD::SMSTOP)
     MAKE_CASE(AArch64ISD::RESTORE_ZA)
+    MAKE_CASE(AArch64ISD::RESTORE_ZT)
+    MAKE_CASE(AArch64ISD::SAVE_ZT)
     MAKE_CASE(AArch64ISD::CALL)
     MAKE_CASE(AArch64ISD::ADRP)
     MAKE_CASE(AArch64ISD::ADR)
@@ -7659,6 +7661,20 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     });
   }
 
+  SDValue ZTFrameIdx;
+  MachineFrameInfo &MFI = MF.getFrameInfo();
+  bool PreserveZT = CallerAttrs.requiresPreservingZT(CalleeAttrs);
+
+  if (PreserveZT) {
+    unsigned ZTObj = MFI.CreateSpillStackObject(64, Align(16));
+    ZTFrameIdx = DAG.getFrameIndex(
+        ZTObj,
+        DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+
+    Chain = DAG.getNode(AArch64ISD::SAVE_ZT, DL, DAG.getVTList(MVT::Other),
+                        {Chain, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
+  }
+
   // Adjust the stack pointer for the new arguments...
   // These operations are automatically eliminated by the prolog/epilog pass
   if (!IsSibCall)
@@ -8077,6 +8093,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
           DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
           DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
 
+      if (PreserveZT)
+        Result =
+            DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
+                        {Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
+
       // Conditionally restore the lazy save using a pseudo node.
       unsigned FI = FuncInfo->getLazySaveTPIDR2Obj();
       SDValue RegMask = DAG.getRegisterMask(
@@ -8105,7 +8126,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         DAG.getConstant(0, DL, MVT::i64));
   }
 
-  if (RequiresSMChange || RequiresLazySave) {
+  if (RequiresSMChange || RequiresLazySave || PreserveZT) {
     for (unsigned I = 0; I < InVals.size(); ++I) {
       // The smstart/smstop is chained as part of the call, but when the
       // resulting chain is discarded (which happens when the call is not part
@@ -23953,6 +23974,14 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
       return DAG.getMergeValues(
           {A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL);
     }
+    case Intrinsic::aarch64_sme_ldr_zt:
+      return DAG.getNode(AArch64ISD::RESTORE_ZT, SDLoc(N),
+                         DAG.getVTList(MVT::Other), N->getOperand(0),
+                         N->getOperand(2), N->getOperand(3));
+    case Intrinsic::aarch64_sme_str_zt:
+      return DAG.getNode(AArch64ISD::SAVE_ZT, SDLoc(N),
+                         DAG.getVTList(MVT::Other), N->getOperand(0),
+                         N->getOperand(2), N->getOperand(3));
     default:
       break;
     }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 6ddbcd41dcb769..6c14bc0aa8dc73 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -61,6 +61,8 @@ enum NodeType : unsigned {
   SMSTART,
   SMSTOP,
   RESTORE_ZA,
+  RESTORE_ZT,
+  SAVE_ZT,
 
   // Produces the full sequence of instructions for getting the thread pointer
   // offset of a variable into X0, using the TLSDesc model.
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 380f6e1fcfdaef..eeae5303a3f898 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -22,6 +22,12 @@ def AArch64_restore_za : SDNode<"AArch64ISD::RESTORE_ZA", SDTypeProfile<0, 3,
                              [SDTCisInt<0>, SDTCisPtrTy<1>]>,
                              [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
                               SDNPOptInGlue]>;
+def AArch64_restore_zt : SDNode<"AArch64ISD::RESTORE_ZT", SDTypeProfile<0, 2,
+                                [SDTCisInt<0>, SDTCisPtrTy<1>]>,
+                                [SDNPHasChain, SDNPSideEffect, SDNPMayLoad]>;
+def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
+                             [SDTCisInt<0>, SDTCisPtrTy<1>]>,
+                             [SDNPHasChain, SDNPSideEffect, SDNPMayStore]>;
 
 //===----------------------------------------------------------------------===//
 // Instruction naming conventions.
@@ -543,8 +549,8 @@ defm UMOPS_MPPZZ_HtoS : sme2_int_mopx_tile<"umops", 0b101, int_aarch64_sme_umops
 
 defm ZERO_T : sme2_zero_zt<"zero", 0b0001>;
 
-defm LDR_TX : sme2_spill_fill_vector<"ldr", 0b01111100, int_aarch64_sme_ldr_zt>;
-defm STR_TX : sme2_spill_fill_vector<"str", 0b11111100, int_aarch64_sme_str_zt>;
+defm LDR_TX : sme2_spill_fill_vector<"ldr", 0b01111100, AArch64_restore_zt>;
+defm STR_TX : sme2_spill_fill_vector<"str", 0b11111100, AArch64_save_zt>;
 
 def MOVT_XTI : sme2_movt_zt_to_scalar<"movt", 0b0011111>;
 def MOVT_TIX : sme2_movt_scalar_to_zt<"movt", 0b0011111>;
diff --git a/llvm/lib/Target/AArch64/SMEABIPass.cpp b/llvm/lib/Target/AArch64/SMEABIPass.cpp
index 3315171798d9f1..4ca0cf648bc147 100644
--- a/llvm/lib/Target/AArch64/SMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/SMEABIPass.cpp
@@ -40,7 +40,8 @@ struct SMEABI : public FunctionPass {
   bool runOnFunction(Function &F) override;
 
 private:
-  bool updateNewZAFunctions(Module *M, Function *F, IRBuilder<> &Builder);
+  bool updateNewZAFunctions(Module *M, Function *F, IRBuilder<> &Builder,
+                            bool ClearZTState);
 };
 } // end anonymous namespace
 
@@ -82,8 +83,8 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder) {
 /// is active and we should call __arm_tpidr2_save to commit the lazy save.
 /// Additionally, PSTATE.ZA should be enabled at the beginning of the function
 /// and disabled before returning.
-bool SMEABI::updateNewZAFunctions(Module *M, Function *F,
-                                  IRBuilder<> &Builder) {
+bool SMEABI::updateNewZAFunctions(Module *M, Function *F, IRBuilder<> &Builder,
+                                  bool ClearZTState) {
   LLVMContext &Context = F->getContext();
   BasicBlock *OrigBB = &F->getEntryBlock();
 
@@ -117,6 +118,14 @@ bool SMEABI::updateNewZAFunctions(Module *M, Function *F,
   Builder.CreateCall(ZeroIntr->getFunctionType(), ZeroIntr,
                      Builder.getInt32(0xff));
 
+  // Clear ZT0 on entry to the function if required, after enabling pstate.za
+  if (ClearZTState) {
+    Function *ClearZT0Intr =
+        Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_zero_zt);
+    Builder.CreateCall(ClearZT0Intr->getFunctionType(), ClearZT0Intr,
+                       {Builder.getInt32(0)});
+  }
+
   // Before returning, disable pstate.za
   for (BasicBlock &BB : *F) {
     Instruction *T = BB.getTerminator();
@@ -143,7 +152,8 @@ bool SMEABI::runOnFunction(Function &F) {
   bool Changed = false;
   SMEAttrs FnAttrs(F);
   if (FnAttrs.hasNewZABody())
-    Changed |= updateNewZAFunctions(M, &F, Builder);
+    Changed |= updateNewZAFunctions(M, &F, Builder,
+                                    FnAttrs.requiresPreservingZT(SMEAttrs()));
 
   return Changed;
 }
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index 0082b4017986c6..ef3a043a15bcc2 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -18,8 +18,10 @@ void SMEAttrs::set(unsigned M, bool Enable) {
   else
     Bitmask &= ~M;
 
+  // Streaming Mode Attrs
   assert(!(hasStreamingInterface() && hasStreamingCompatibleInterface()) &&
          "SM_Enabled and SM_Compatible are mutually exclusive");
+  // ZA Attrs
   assert(!(hasNewZABody() && hasSharedZAInterface()) &&
          "ZA_New and ZA_Shared are mutually exclusive");
   assert(!(hasNewZABody() && preservesZA()) &&
@@ -28,6 +30,11 @@ void SMEAttrs::set(unsigned M, bool Enable) {
          "ZA_New and ZA_NoLazySave are mutually exclusive");
   assert(!(hasSharedZAInterface() && (Bitmask & ZA_NoLazySave)) &&
          "ZA_Shared and ZA_NoLazySave are mutually exclusive");
+  // ZT Attrs
+  assert(!(hasNewZTBody() && hasSharedZTInterface()) &&
+         "ZT_New and ZT_Shared are mutually exclusive");
+  assert(!(hasNewZTBody() && preservesZT()) &&
+         "ZT_New and ZT_Preserved are mutually exclusive");
 }
 
 SMEAttrs::SMEAttrs(const CallBase &CB) {
@@ -40,10 +47,10 @@ SMEAttrs::SMEAttrs(const CallBase &CB) {
 SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
   if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
     Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Preserved |
-                SMEAttrs::ZA_NoLazySave);
+                SMEAttrs::ZA_NoLazySave | SMEAttrs::ZT_Preserved);
   if (FuncName == "__arm_tpidr2_restore")
     Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Shared |
-                SMEAttrs::ZA_NoLazySave);
+                SMEAttrs::ZA_NoLazySave | SMEAttrs::ZT_Shared);
 }
 
 SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
@@ -60,6 +67,12 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
     Bitmask |= ZA_New;
   if (Attrs.hasFnAttr("aarch64_pstate_za_preserved"))
     Bitmask |= ZA_Preserved;
+  if (Attrs.hasFnAttr("aarch64_sme_pstate_zt0_shared"))
+    Bitmask |= ZT_Shared;
+  if (Attrs.hasFnAttr("aarch64_sme_pstate_zt0_new"))
+    Bitmask |= ZT_New;
+  if (Attrs.hasFnAttr("aarch64_sme_pstate_zt0_preserved"))
+    Bitmask |= ZT_Preserved;
 }
 
 std::optional<bool>
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index e766b778b54102..3eceaf95a249a2 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -36,7 +36,10 @@ class SMEAttrs {
     ZA_New = 1 << 4,        // aarch64_pstate_sm_new
     ZA_Preserved = 1 << 5,  // aarch64_pstate_sm_preserved
     ZA_NoLazySave = 1 << 6, // Used for SME ABI routines to avoid lazy saves
-    All = ZA_Preserved - 1
+    ZT_New = 1 << 7,        // aarch64_sme_pstate_zt0_new
+    ZT_Shared = 1 << 8,     // aarch64_sme_pstate_zt0_shared
+    ZT_Preserved = 1 << 9,  // aarch64_sme_pstate_zt0_preserved
+    All = ZT_Preserved - 1
   };
 
   SMEAttrs(unsigned Mask = Normal) : Bitmask(0) { set(Mask); }
@@ -74,6 +77,14 @@ class SMEAttrs {
   requiresSMChange(const SMEAttrs &Callee,
                    bool BodyOverridesInterface = false) const;
 
+  /// \return true if a call from Caller -> Callee requires ZT0 state to be
+  /// preserved.
+  /// ZT0 must be preserved if the caller has ZT state and the callee
+  /// does not preserve ZT.
+  bool requiresPreservingZT(const SMEAttrs &Callee) const {
+    return hasZTState() && !Callee.preservesZT();
+  }
+
   // Interfaces to query PSTATE.ZA
   bool hasNewZABody() const { return Bitmask & ZA_New; }
   bool hasSharedZAInterface() const { return Bitmask & ZA_Shared; }
@@ -82,6 +93,13 @@ class SMEAttrs {
   bool hasZAState() const {
     return hasNewZABody() || hasSharedZAInterface();
   }
+
+  // Interfaces to query ZT0 state
+  bool hasNewZTBody() const { return Bitmask & ZT_New; }
+  bool hasSharedZTInterface() const { return Bitmask & ZT_Shared; }
+  bool preservesZT() const { return Bitmask & ZT_Preserved; }
+  bool hasZTState() const { return hasNewZTBody() || hasSharedZTInterface(); }
+
   bool requiresLazySave(const SMEAttrs &Callee) const {
     return hasZAState() && Callee.hasPrivateZAInterface() &&
            !(Callee.Bitmask & ZA_NoLazySave);
diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-preserve.ll b/llvm/test/CodeGen/AArch64/sme-zt0-preserve.ll
new file mode 100644
index 00000000000000..bbcfd5cac197b5
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-preserve.ll
@@ -0,0 +1,306 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -start-after=simplifycfg -enable-tail-merge=false -verify-machineinstrs < %s | FileCheck %s
+
+; Normal callee, no ZT state
+declare void @normal_callee();
+
+; Callees with ZT state
+declare void @za_shared_callee() "aarch64_pstate_za_shared" "aarch64_sme_pstate_zt0_shared";
+declare void @za_new_callee() "aarch64_pstate_za_new" "aarch64_sme_pstate_zt0_new";
+
+; Callee with preserved ZT state
+declare void @za_preserved_callee() "aarch64_pstate_za_preserved" "aarch64_sme_pstate_zt0_preserved";
+
+
+define void @za_zt_new_caller_normal_callee() "aarch64_pstate_za_new" "aarch64_sme_pstate_zt0_new" nounwind {
+; CHECK-LABEL: za_zt_new_caller_normal_callee:
+; CHECK:       // %bb.0: // %prelude
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #80
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbz x8, .LBB0_2
+; CHECK-NEXT:  // %bb.1: // %save.za
+; CHECK-NEXT:    bl __arm_tpidr2_save
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:  .LBB0_2:
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    zero {za}
+; CHECK-NEXT:    zero { zt0 }
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    sub x9, x29, #16
+; CHECK-NEXT:    sub x19, x29, #80
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x9
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    bl normal_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    cbnz x8, .LBB0_4
+; CHECK-NEXT:  // %bb.3:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB0_4:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    smstop za
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @normal_callee();
+  ret void;
+}
+
+define void @za_zt_new_caller_za_callee() "aarch64_pstate_za_new" "aarch64_sme_pstate_zt0_new" nounwind {
+; CHECK-LABEL: za_zt_new_caller_za_callee:
+; CHECK:       // %bb.0: // %prelude
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #144
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbz x8, .LBB1_2
+; CHECK-NEXT:  // %bb.1: // %save.za
+; CHECK-NEXT:    bl __arm_tpidr2_save
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:  .LBB1_2:
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    zero {za}
+; CHECK-NEXT:    zero { zt0 }
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    sub x9, x29, #16
+; CHECK-NEXT:    sub x19, x29, #80
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x9
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    bl za_new_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    cbnz x8, .LBB1_4
+; CHECK-NEXT:  // %bb.3:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB1_4:
+; CHECK-NEXT:    sub x8, x29, #144
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    str zt0, [x8]
+; CHECK-NEXT:    bl za_shared_callee
+; CHECK-NEXT:    smstop za
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @za_new_callee();
+  call void @za_shared_callee();
+  ret void;
+}
+
+define void @za_zt_shared_caller_normal_callee() "aarch64_pstate_za_shared" "aarch64_sme_pstate_zt0_shared" nounwind {
+; CHECK-LABEL: za_zt_shared_caller_normal_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #80
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x9, x8, x8, x9
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    sub x19, x29, #80
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x9, [x29, #-16]
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    bl normal_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    cbnz x8, .LBB2_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB2_2:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @normal_callee();
+  ret void;
+}
+
+define void @za_zt_shared_caller_za_callee() "aarch64_pstate_za_shared" "aarch64_sme_pstate_zt0_shared" nounwind {
+; CHECK-LABEL: za_zt_shared_caller_za_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #144
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x9, x8, x8, x9
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    sub x19, x29, #80
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x9, [x29, #-16]
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    bl za_new_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    cbnz x8, .LBB3_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB3_2:
+; CHECK-NEXT:    sub x8, x29, #144
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    str zt0, [x8]
+; CHECK-NEXT:    bl za_shared_callee
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @za_new_callee();
+  call void @za_shared_callee();
+  ret void;
+}
+
+define void @za_zt_new_caller_za_preserved_callee() "aarch64_pstate_za_new" "aarch64_sme_pstate_zt0_new" nounwind {
+; CHECK-LABEL: za_zt_new_caller_za_preserved_callee:
+; CHECK:       // %bb.0: // %prelude
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsv...
[truncated]
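The retargeted spill/fill intrinsics from the AArch64ISelLowering.cpp hunk can also be exercised directly from IR. A minimal sketch, assuming the (i32, ptr) operand signature implied by the lowering code above (the function name and %buf are illustrative; %buf must point to a valid 64-byte buffer):

; Explicitly spill and reload ZT0. With this patch these intrinsics lower to
; the new AArch64ISD::SAVE_ZT / RESTORE_ZT nodes, which select to
; "str zt0" / "ldr zt0".
declare void @llvm.aarch64.sme.str.zt(i32, ptr)
declare void @llvm.aarch64.sme.ldr.zt(i32, ptr)

define void @spill_fill_zt0(ptr %buf) "aarch64_pstate_za_shared" "aarch64_sme_pstate_zt0_shared" nounwind {
  call void @llvm.aarch64.sme.str.zt(i32 0, ptr %buf)
  call void @llvm.aarch64.sme.ldr.zt(i32 0, ptr %buf)
  ret void
}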

sdesmalen-arm (Collaborator) left a comment

Hi @kmclaughlin-arm, your patch mimics the existing implementation for ZA, which itself was based on the previous ACLE attribute keywords. Since these are being changed in ARM-software/acle#276 (with an implementation in #76971), I think it makes sense to create LLVM IR attributes that more closely match the new Clang attributes, even if we don't treat them differently than we do for ZA (for example, we could treat __arm_in/out/inout/preserves("zt0") all as "sharing zt0" from an ABI perspective).

The new attributes also make it possible to consider ZA separately from ZT0, so I think we should implement it as such in this pull request as well. Perhaps it helps if I share my understanding of the ACLE/ABI with these new attributes:

If a function has __arm_in/out/inout/preserves(S) for S = "za" or "zt0", then the function has a shared-ZA interface. Otherwise it has a private-ZA interface.

For calls to a "Private ZA" interface:

  • If the caller has live ZA state and the callee is a private ZA function, then the caller sets up the lazy-save and restores it after the call.

  • If the caller has live ZT0 state and the callee is a Private ZA function, then the caller spills ZT0 and reloads it after the call.

For calls to a "Shared ZA" interface:

  • If the caller has live ZA state and the callee is a shared ZA function, but doesn't share ZA, then the caller has to spill ZA and reload it after the call.

  • If the caller has live ZT0 state and the callee is a shared ZA function, but doesn't share ZT0, then the caller has to spill ZT0 and reload it after the call.

For functions with new ZA/ZT0 state:

  • If a function is a shared-ZA function and has __arm_new("za"), then it doesn't have to commit a lazy save, because one can't be active. It also doesn't have to enable PSTATE.ZA, because it is already active. It only has to zero ZA.

  • If a function is a shared-ZA function and has __arm_new("zt0"), then it doesn't have to enable PSTATE.ZA, because it is already active. It only has to zero ZT0.

Note that I'm not necessarily suggesting implementing all of the above cases in this one PR; it's fine to do this work incrementally.
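As a concrete instance of the second shared-ZA bullet, sketched with this patch's attribute spellings (the renamed attributes weren't final at this point, and the function names are illustrative):

; The caller shares both ZA and ZT0; the callee shares ZA but not ZT0.
; ZA needs no lazy save across this call, but ZT0 is not guaranteed to be
; preserved, so the caller must spill ZT0 before the call and reload it after.
declare void @shares_za_only() "aarch64_pstate_za_shared"

define void @shares_za_and_zt0() "aarch64_pstate_za_shared" "aarch64_sme_pstate_zt0_shared" nounwind {
  call void @shares_za_only()
  ret void
}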

A collaborator left an inline comment on this hunk of the new test:

%loadval = load i32, ptr %ptr
call void @za_shared_callee()
ret i32 %loadval
}

 declare void @bar() 
 define void @foo() "aarch64_sme_pstate_zt0_shared" nounwind {
   call void @bar()
   ret void
 }

Currently, this does not emit a fill for zt0 after the call to @bar:

foo:                                    // @foo
// %bb.0:
        sub     sp, sp, #80
        mov     x8, sp
        str     x30, [sp, #64]                  // 8-byte Folded Spill
        str     zt0, [x8]
        bl      bar
        ldr     x30, [sp, #64]                  // 8-byte Folded Reload
        add     sp, sp, #80
        ret

kmclaughlin-arm marked this pull request as draft on January 16, 2024.
kmclaughlin-arm (Author)

Closing this pull request as the following patches have been committed for handling ZT0 state based on the changes requested above:

kmclaughlin-arm deleted the sme2-zt0-llvm branch on June 14, 2024.