[AArch64][SME2] Preserve ZT0 state around function calls #78321

Merged
merged 8 commits into llvm:main from sme2-shared-za-zt0 on Jan 20, 2024

Conversation

kmclaughlin-arm (Contributor)

If a function has ZT0 state and calls a function which does not
preserve ZT0, the caller must save and restore ZT0 around the call.
If the caller shares ZT0 state and the callee does not share ZA, we must
additionally issue SMSTOP/SMSTART ZA around the call.

This patch adds new AArch64ISD nodes for spilling & filling ZT0.
Where requiresPreservingZT0 is true, ZT0 state will be preserved
across a call.
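
As an illustration (not part of the patch), the simplest case is a caller that shares ZT0 calling a function with no SME state, as in the sme-zt0-state.ll test below; the caller and callee names here are made up for the example:

  declare void @private_za_callee()

  ; ZT0 is saved before the call and restored after it; because no
  ; lazy-save of ZA is active, PSTATE.ZA is also stopped and restarted.
  define void @zt0_caller() "aarch64_in_zt0" nounwind {
    call void @private_za_callee()
    ret void
  }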

llvmbot (Collaborator) commented Jan 16, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Kerry McLaughlin (kmclaughlin-arm)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/78321.diff

5 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+49-7)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+2)
  • (modified) llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td (+8-2)
  • (modified) llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h (+3)
  • (added) llvm/test/CodeGen/AArch64/sme-zt0-state.ll (+142)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 620872790ed8db..6df4e075e1ca50 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2341,6 +2341,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::SMSTART)
     MAKE_CASE(AArch64ISD::SMSTOP)
     MAKE_CASE(AArch64ISD::RESTORE_ZA)
+    MAKE_CASE(AArch64ISD::RESTORE_ZT)
+    MAKE_CASE(AArch64ISD::SAVE_ZT)
     MAKE_CASE(AArch64ISD::CALL)
     MAKE_CASE(AArch64ISD::ADRP)
     MAKE_CASE(AArch64ISD::ADR)
@@ -7664,6 +7666,32 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     });
   }
 
+  SDValue ZTFrameIdx;
+  MachineFrameInfo &MFI = MF.getFrameInfo();
+  bool PreserveZT0 = CallerAttrs.requiresPreservingZT0(CalleeAttrs);
+
+  // If the caller has ZT0 state which will not be preserved by the callee,
+  // spill ZT0 before the call.
+  if (PreserveZT0) {
+    unsigned ZTObj = MFI.CreateSpillStackObject(64, Align(16));
+    ZTFrameIdx = DAG.getFrameIndex(
+        ZTObj,
+        DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+
+    Chain = DAG.getNode(AArch64ISD::SAVE_ZT, DL, DAG.getVTList(MVT::Other),
+                        {Chain, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
+  }
+
+  // If caller shares ZT0 but the callee is not shared ZA, we need to stop
+  // PSTATE.ZA before the call if there is no lazy-save active.
+  bool ToggleZA = !RequiresLazySave && CallerAttrs.sharesZT0() &&
+                  CalleeAttrs.hasPrivateZAInterface();
+  if (ToggleZA)
+    Chain = DAG.getNode(
+        AArch64ISD::SMSTOP, DL, MVT::Other, Chain,
+        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
+        DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+
   // Adjust the stack pointer for the new arguments...
   // These operations are automatically eliminated by the prolog/epilog pass
   if (!IsSibCall)
@@ -8074,14 +8102,20 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
                                  PStateSM, false);
   }
 
+  if ((RequiresLazySave && !CalleeAttrs.preservesZA()) || ToggleZA)
+    // Unconditionally resume ZA.
+    Result = DAG.getNode(
+        AArch64ISD::SMSTART, DL, MVT::Other, Result,
+        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
+        DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+
+  if (PreserveZT0)
+    Result =
+        DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
+                    {Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
+
   if (RequiresLazySave) {
     if (!CalleeAttrs.preservesZA()) {
-      // Unconditionally resume ZA.
-      Result = DAG.getNode(
-          AArch64ISD::SMSTART, DL, MVT::Other, Result,
-          DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
-          DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
-
       // Conditionally restore the lazy save using a pseudo node.
       unsigned FI = FuncInfo->getLazySaveTPIDR2Obj();
       SDValue RegMask = DAG.getRegisterMask(
@@ -8110,7 +8144,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         DAG.getConstant(0, DL, MVT::i64));
   }
 
-  if (RequiresSMChange || RequiresLazySave) {
+  if (RequiresSMChange || RequiresLazySave || PreserveZT0) {
     for (unsigned I = 0; I < InVals.size(); ++I) {
       // The smstart/smstop is chained as part of the call, but when the
       // resulting chain is discarded (which happens when the call is not part
@@ -23979,6 +24013,14 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
       return DAG.getMergeValues(
           {A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL);
     }
+    case Intrinsic::aarch64_sme_ldr_zt:
+      return DAG.getNode(AArch64ISD::RESTORE_ZT, SDLoc(N),
+                         DAG.getVTList(MVT::Other), N->getOperand(0),
+                         N->getOperand(2), N->getOperand(3));
+    case Intrinsic::aarch64_sme_str_zt:
+      return DAG.getNode(AArch64ISD::SAVE_ZT, SDLoc(N),
+                         DAG.getVTList(MVT::Other), N->getOperand(0),
+                         N->getOperand(2), N->getOperand(3));
     default:
       break;
     }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 1fd639b4f7ee8f..bffee867fdf294 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -61,6 +61,8 @@ enum NodeType : unsigned {
   SMSTART,
   SMSTOP,
   RESTORE_ZA,
+  RESTORE_ZT,
+  SAVE_ZT,
 
   // Produces the full sequence of instructions for getting the thread pointer
   // offset of a variable into X0, using the TLSDesc model.
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 380f6e1fcfdaef..eeae5303a3f898 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -22,6 +22,12 @@ def AArch64_restore_za : SDNode<"AArch64ISD::RESTORE_ZA", SDTypeProfile<0, 3,
                              [SDTCisInt<0>, SDTCisPtrTy<1>]>,
                              [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
                               SDNPOptInGlue]>;
+def AArch64_restore_zt : SDNode<"AArch64ISD::RESTORE_ZT", SDTypeProfile<0, 2,
+                                [SDTCisInt<0>, SDTCisPtrTy<1>]>,
+                                [SDNPHasChain, SDNPSideEffect, SDNPMayLoad]>;
+def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
+                             [SDTCisInt<0>, SDTCisPtrTy<1>]>,
+                             [SDNPHasChain, SDNPSideEffect, SDNPMayStore]>;
 
 //===----------------------------------------------------------------------===//
 // Instruction naming conventions.
@@ -543,8 +549,8 @@ defm UMOPS_MPPZZ_HtoS : sme2_int_mopx_tile<"umops", 0b101, int_aarch64_sme_umops
 
 defm ZERO_T : sme2_zero_zt<"zero", 0b0001>;
 
-defm LDR_TX : sme2_spill_fill_vector<"ldr", 0b01111100, int_aarch64_sme_ldr_zt>;
-defm STR_TX : sme2_spill_fill_vector<"str", 0b11111100, int_aarch64_sme_str_zt>;
+defm LDR_TX : sme2_spill_fill_vector<"ldr", 0b01111100, AArch64_restore_zt>;
+defm STR_TX : sme2_spill_fill_vector<"str", 0b11111100, AArch64_save_zt>;
 
 def MOVT_XTI : sme2_movt_zt_to_scalar<"movt", 0b0011111>;
 def MOVT_TIX : sme2_movt_scalar_to_zt<"movt", 0b0011111>;
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index af2854856fb979..417dec3432a008 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -119,6 +119,9 @@ class SMEAttrs {
            State == StateValue::InOut || State == StateValue::Preserved;
   }
   bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
+  bool requiresPreservingZT0(const SMEAttrs &Callee) const {
+    return hasZT0State() && !Callee.sharesZT0();
+  }
 };
 
 } // namespace llvm
diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
new file mode 100644
index 00000000000000..289794d12be171
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -0,0 +1,142 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -start-after=simplifycfg -enable-tail-merge=false -verify-machineinstrs < %s | FileCheck %s
+
+; Callee with no ZT state
+declare void @no_state_callee();
+
+; Callees with ZT0 state
+declare void @zt0_shared_callee() "aarch64_in_zt0";
+
+; Callees with ZA state
+
+declare void @za_shared_callee() "aarch64_pstate_za_shared";
+declare void @za_zt0_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
+
+;
+; Private-ZA Callee
+;
+
+; Expect spill & fill of ZT0 around call
+; Expect smstop/smstart za around call
+define void @zt0_in_caller_no_state_callee() "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: zt0_in_caller_no_state_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sub sp, sp, #80
+; CHECK-NEXT:    stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    mov x19, sp
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    smstop za
+; CHECK-NEXT:    bl no_state_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    add sp, sp, #80
+; CHECK-NEXT:    ret
+  call void @no_state_callee();
+  ret void;
+}
+
+; Expect spill & fill of ZT0 around call
+; Expect setup and restore lazy-save around call
+; Expect smstart za after call
+define void @za_zt0_shared_caller_no_state_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: za_zt0_shared_caller_no_state_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #80
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x9, x8, x8, x9
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    sub x19, x29, #80
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x9, [x29, #-16]
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    bl no_state_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    cbnz x8, .LBB1_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB1_2:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @no_state_callee();
+  ret void;
+}
+
+;
+; Shared-ZA Callee
+;
+
+; Caller and callee have shared ZT0 state, no spill/fill of ZT0 required
+define void @zt0_shared_caller_zt0_shared_callee() "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: zt0_shared_caller_zt0_shared_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl zt0_shared_callee
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @zt0_shared_callee();
+  ret void;
+}
+
+; Expect spill & fill of ZT0 around call
+define void @za_zt0_shared_caller_za_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: za_zt0_shared_caller_za_shared_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #80
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    sub x19, x29, #80
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    bl za_shared_callee
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @za_shared_callee();
+  ret void;
+}
+
+; Caller and callee have shared ZA & ZT0
+define void @za_zt0_shared_caller_za_zt0_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: za_zt0_shared_caller_za_zt0_shared_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    bl za_zt0_shared_callee
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @za_zt0_shared_callee();
+  ret void;
+}

github-actions bot commented Jan 18, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Comment on lines 289 to 296
ASSERT_TRUE(ZT0_New.requiresPreservingZT0(SA(SA::Normal)));
ASSERT_TRUE(ZT0_New.requiresDisablingZABeforeCall(SA(SA::Normal)));
ASSERT_TRUE(ZT0_New.requiresEnablingZAAfterCall(SA(SA::Normal)));

// ZT0 New -> ZT0 New
ASSERT_TRUE(ZT0_New.requiresPreservingZT0(ZT0_New));
ASSERT_TRUE(ZT0_New.requiresDisablingZABeforeCall(ZT0_New));
ASSERT_TRUE(ZT0_New.requiresEnablingZAAfterCall(ZT0_New));
Collaborator
Normal and ZT0_New have the same interface, so there's no value in testing them both.

ASSERT_TRUE(ZT0_New.requiresEnablingZAAfterCall(ZT0_New));

// ZT0 New -> ZT0 Shared
ASSERT_FALSE(ZT0_New.requiresPreservingZT0(ZT0_In));
Collaborator

I don't see much value in testing all these different variations of "sharing ZT0" (ZT0_In/InOut/Out/Preserves) because Callee.hasPrivateZAInterface() and Callee.sharesZT0() return the same values for all of them (which is also tested elsewhere), so they're all testing the same code-path + data for these functions.

Comment on lines 316 to 364
ASSERT_TRUE(ZT0_In.requiresPreservingZT0(SA(SA::Normal)));
ASSERT_TRUE(ZT0_In.requiresDisablingZABeforeCall(SA(SA::Normal)));
ASSERT_TRUE(ZT0_In.requiresEnablingZAAfterCall(SA(SA::Normal)));

ASSERT_TRUE(ZT0_InOut.requiresPreservingZT0(SA(SA::Normal)));
ASSERT_TRUE(ZT0_InOut.requiresDisablingZABeforeCall(SA(SA::Normal)));
ASSERT_TRUE(ZT0_InOut.requiresEnablingZAAfterCall(SA(SA::Normal)));

ASSERT_TRUE(ZT0_Out.requiresPreservingZT0(SA(SA::Normal)));
ASSERT_TRUE(ZT0_Out.requiresDisablingZABeforeCall(SA(SA::Normal)));
ASSERT_TRUE(ZT0_Out.requiresEnablingZAAfterCall(SA(SA::Normal)));

ASSERT_TRUE(ZT0_Preserved.requiresPreservingZT0(SA(SA::Normal)));
ASSERT_TRUE(ZT0_Preserved.requiresDisablingZABeforeCall(SA(SA::Normal)));
ASSERT_TRUE(ZT0_Preserved.requiresEnablingZAAfterCall(SA(SA::Normal)));

// ZT0 Shared -> ZT0 Shared
ASSERT_FALSE(ZT0_In.requiresPreservingZT0(ZT0_In));
ASSERT_FALSE(ZT0_In.requiresDisablingZABeforeCall(ZT0_In));
ASSERT_FALSE(ZT0_In.requiresEnablingZAAfterCall(ZT0_In));

ASSERT_FALSE(ZT0_InOut.requiresPreservingZT0(ZT0_In));
ASSERT_FALSE(ZT0_InOut.requiresDisablingZABeforeCall(ZT0_In));
ASSERT_FALSE(ZT0_InOut.requiresEnablingZAAfterCall(ZT0_In));

ASSERT_FALSE(ZT0_Out.requiresPreservingZT0(ZT0_In));
ASSERT_FALSE(ZT0_Out.requiresDisablingZABeforeCall(ZT0_In));
ASSERT_FALSE(ZT0_Out.requiresEnablingZAAfterCall(ZT0_In));

ASSERT_FALSE(ZT0_Preserved.requiresPreservingZT0(ZT0_In));
ASSERT_FALSE(ZT0_Preserved.requiresDisablingZABeforeCall(ZT0_In));
ASSERT_FALSE(ZT0_Preserved.requiresEnablingZAAfterCall(ZT0_In));

// ZT0 Shared -> ZT0 New
ASSERT_TRUE(ZT0_In.requiresPreservingZT0(ZT0_New));
ASSERT_TRUE(ZT0_In.requiresDisablingZABeforeCall(ZT0_New));
ASSERT_TRUE(ZT0_In.requiresEnablingZAAfterCall(ZT0_New));

ASSERT_TRUE(ZT0_InOut.requiresPreservingZT0(ZT0_New));
ASSERT_TRUE(ZT0_InOut.requiresDisablingZABeforeCall(ZT0_New));
ASSERT_TRUE(ZT0_InOut.requiresEnablingZAAfterCall(ZT0_New));

ASSERT_TRUE(ZT0_Out.requiresPreservingZT0(ZT0_New));
ASSERT_TRUE(ZT0_Out.requiresDisablingZABeforeCall(ZT0_New));
ASSERT_TRUE(ZT0_Out.requiresEnablingZAAfterCall(ZT0_New));

ASSERT_TRUE(ZT0_Preserved.requiresPreservingZT0(ZT0_New));
ASSERT_TRUE(ZT0_Preserved.requiresDisablingZABeforeCall(ZT0_New));
ASSERT_TRUE(ZT0_Preserved.requiresEnablingZAAfterCall(ZT0_New));
Collaborator

Similar comment for these cases: they're not really testing anything that isn't already tested by the other tests. ZT0_In/Out/InOut/Preserves/New all return true for hasZT0State(), so the codepath and data will be the same for these functions.

I think it's more valuable to add some tests for ZA (because of the lazy-save mechanism) and perhaps even some in combination with ZT0.

sdesmalen-arm (Collaborator) left a comment

One small comment on the test, but otherwise I'm happy with the PR.

Comment on lines 4 to 14
; Callee with no ZT state
declare void @no_state_callee();

; Callees with ZT0 state
declare void @zt0_shared_callee() "aarch64_in_zt0";
declare void @zt0_new_callee() "aarch64_new_zt0";

; Callees with ZA state

declare void @za_shared_callee() "aarch64_pstate_za_shared";
declare void @za_zt0_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
Collaborator

Sorry I didn't spot this earlier, but you can create a single declare void @callee() and then add the attributes on the callsite. That avoids the need for all these declarations.
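
For example, a sketch of that suggestion (the caller name is made up for the example; the attribute moves from the declaration to the callsite):

  declare void @callee()

  define void @zt0_caller_za_callee() "aarch64_in_zt0" nounwind {
    ; The callee's SME interface is given at the callsite rather than
    ; on a separate declaration.
    call void @callee() "aarch64_pstate_za_shared"
    ret void
  }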

Follow-up commits:

- Added a test for an aarch64_in_zt0 caller -> aarch64_new_zt0 callee
- Renamed PreserveZT0 to ShouldPreserveZT0
- Added unit tests for requiresPreservingZT0, requiresDisablingZABeforeCall and requiresEnablingZAAfterCall
- Used a single callee() function in sme-zt0-state.ll and added attributes at the callsite
sdesmalen-arm (Collaborator) left a comment

LGTM

kmclaughlin-arm merged commit a8a3711 into llvm:main on Jan 20, 2024
3 of 4 checks passed
kmclaughlin-arm deleted the sme2-shared-za-zt0 branch on June 14, 2024 at 12:55