
[AArch64][SME2] Extend SMEABIPass to handle functions with new ZT0 state #78848

Merged: 4 commits into llvm:main, Jan 22, 2024

Conversation

kmclaughlin-arm (Contributor)

updateNewZAFunctions is extended to generate the following on entry to a
function with either the "aarch64_pstate_za_new" or "arm_new_zt0" attribute:

  • Private-ZA interface: commit any active lazy-saves & enable PSTATE.ZA.
  • "aarch64_pstate_za_new": zero ZA.
  • "arm_new_zt0": zero ZT0.

Additionally, PSTATE.ZA should be disabled before returning if the function
has a private-ZA interface.
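
For illustration, here is a rough, hand-written sketch of the IR this produces for a function carrying both attributes. It is not verbatim pass output: the function name @example is made up, the save.za body is abbreviated to a comment, and the attribute spelling "aarch64_new_zt0" follows the tests in this patch. The intrinsic names mirror the Intrinsic:: identifiers used in the change.

declare i64 @llvm.aarch64.sme.get.tpidr2()
declare void @llvm.aarch64.sme.za.enable()
declare void @llvm.aarch64.sme.zero(i32)
declare void @llvm.aarch64.sme.zero.zt(i32)
declare void @llvm.aarch64.sme.za.disable()

define void @example() "aarch64_pstate_za_new" "aarch64_new_zt0" nounwind {
prelude:
  ; Private-ZA interface: check whether a lazy save is active (TPIDR2_EL0 != 0).
  %tpidr2 = call i64 @llvm.aarch64.sme.get.tpidr2()
  %cmp = icmp ne i64 %tpidr2, 0
  br i1 %cmp, label %save.za, label %entry

save.za:
  ; __arm_tpidr2_save is called here to commit the lazy save and TPIDR2_EL0 is
  ; cleared (abbreviated in this sketch).
  br label %entry

entry:
  call void @llvm.aarch64.sme.za.enable()    ; enable PSTATE.ZA (private-ZA interface)
  call void @llvm.aarch64.sme.zero(i32 255)  ; zero ZA ("aarch64_pstate_za_new")
  call void @llvm.aarch64.sme.zero.zt(i32 0) ; zero ZT0 (new ZT0 state)
  ; ... original function body ...
  call void @llvm.aarch64.sme.za.disable()   ; disable PSTATE.ZA before returning
  ret void
}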

…tate

updateNewZAFunctions is extended to generate the following on entry
to a function with either the "aarch64_pstate_za_new" or
"arm_new_zt0" attribute:
 - Private-ZA interface: commit any active lazy-saves & enable PSTATE.ZA.
 - "aarch64_pstate_za_new": zero ZA.
 - "arm_new_zt0": zero ZT0.

Additionally, PSTATE.ZA should be disabled before returning if the function
has a private-ZA interface.
llvmbot (Collaborator) commented Jan 20, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Kerry McLaughlin (kmclaughlin-arm)

Changes

updateNewZAFunctions is extended to generate the following on entry to a
function with either the "aarch64_pstate_za_new" or "arm_new_zt0" attribute:

  • Private-ZA interface: commit any active lazy-saves & enable PSTATE.ZA.
  • "aarch64_pstate_za_new": zero ZA.
  • "arm_new_zt0": zero ZT0.

Additionally, PSTATE.ZA should be disabled before returning if the function
has a private-ZA interface.


Full diff: https://github.com/llvm/llvm-project/pull/78848.diff

3 Files Affected:

  • (modified) llvm/lib/Target/AArch64/SMEABIPass.cpp (+74-50)
  • (modified) llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp (+6-4)
  • (modified) llvm/test/CodeGen/AArch64/sme-zt0-state.ll (+107)
diff --git a/llvm/lib/Target/AArch64/SMEABIPass.cpp b/llvm/lib/Target/AArch64/SMEABIPass.cpp
index 3315171798d9f1..0450e2f6f286e1 100644
--- a/llvm/lib/Target/AArch64/SMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/SMEABIPass.cpp
@@ -40,7 +40,8 @@ struct SMEABI : public FunctionPass {
   bool runOnFunction(Function &F) override;
 
 private:
-  bool updateNewZAFunctions(Module *M, Function *F, IRBuilder<> &Builder);
+  bool updateNewStateFunctions(Module *M, Function *F, IRBuilder<> &Builder,
+                               SMEAttrs FnAttrs);
 };
 } // end anonymous namespace
 
@@ -76,56 +77,79 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder) {
                      Builder.getInt64(0));
 }
 
-/// This function generates code to commit a lazy save at the beginning of a
-/// function marked with `aarch64_pstate_za_new`. If the value read from
-/// TPIDR2_EL0 is not null on entry to the function then the lazy-saving scheme
-/// is active and we should call __arm_tpidr2_save to commit the lazy save.
-/// Additionally, PSTATE.ZA should be enabled at the beginning of the function
-/// and disabled before returning.
-bool SMEABI::updateNewZAFunctions(Module *M, Function *F,
-                                  IRBuilder<> &Builder) {
+/// This function generates code at the beginning and end of a function marked
+/// with either `aarch64_pstate_za_new` or `arm_new_zt0`.
+/// At the beginning of the function, the following code is generated:
+///  - Commit lazy-save if active   [Private-ZA Interface]
+///  - Enable PSTATE.ZA             [Private-ZA Interface]
+///  - Zero ZA                      [Has New ZA State]
+///  - Zero ZT0                     [Has New ZT0 State]
+/// At the end of the function, PSTATE.ZA is disabled if the function has a
+/// Private-ZA Interface. A function is considered to have a Private-ZA
+/// interface if it does not share ZA or ZT0.
+///
+bool SMEABI::updateNewStateFunctions(Module *M, Function *F,
+                                     IRBuilder<> &Builder, SMEAttrs FnAttrs) {
   LLVMContext &Context = F->getContext();
   BasicBlock *OrigBB = &F->getEntryBlock();
-
-  // Create the new blocks for reading TPIDR2_EL0 & enabling ZA state.
-  auto *SaveBB = OrigBB->splitBasicBlock(OrigBB->begin(), "save.za", true);
-  auto *PreludeBB = BasicBlock::Create(Context, "prelude", F, SaveBB);
-
-  // Read TPIDR2_EL0 in PreludeBB & branch to SaveBB if not 0.
-  Builder.SetInsertPoint(PreludeBB);
-  Function *TPIDR2Intr =
-      Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_get_tpidr2);
-  auto *TPIDR2 = Builder.CreateCall(TPIDR2Intr->getFunctionType(), TPIDR2Intr,
-                                    {}, "tpidr2");
-  auto *Cmp =
-      Builder.CreateCmp(ICmpInst::ICMP_NE, TPIDR2, Builder.getInt64(0), "cmp");
-  Builder.CreateCondBr(Cmp, SaveBB, OrigBB);
-
-  // Create a call __arm_tpidr2_save, which commits the lazy save.
-  Builder.SetInsertPoint(&SaveBB->back());
-  emitTPIDR2Save(M, Builder);
-
-  // Enable pstate.za at the start of the function.
   Builder.SetInsertPoint(&OrigBB->front());
-  Function *EnableZAIntr =
-      Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_za_enable);
-  Builder.CreateCall(EnableZAIntr->getFunctionType(), EnableZAIntr);
-
-  // ZA state must be zeroed upon entry to a function with NewZA
-  Function *ZeroIntr =
-      Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_zero);
-  Builder.CreateCall(ZeroIntr->getFunctionType(), ZeroIntr,
-                     Builder.getInt32(0xff));
-
-  // Before returning, disable pstate.za
-  for (BasicBlock &BB : *F) {
-    Instruction *T = BB.getTerminator();
-    if (!T || !isa<ReturnInst>(T))
-      continue;
-    Builder.SetInsertPoint(T);
-    Function *DisableZAIntr =
-        Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_za_disable);
-    Builder.CreateCall(DisableZAIntr->getFunctionType(), DisableZAIntr);
+
+  // Commit any active lazy-saves if this is a Private-ZA function. If the
+  // value read from TPIDR2_EL0 is not null on entry to the function then
+  // the lazy-saving scheme is active and we should call __arm_tpidr2_save
+  // to commit the lazy save.
+  if (FnAttrs.hasPrivateZAInterface()) {
+    // Create the new blocks for reading TPIDR2_EL0 & enabling ZA state.
+    auto *SaveBB = OrigBB->splitBasicBlock(OrigBB->begin(), "save.za", true);
+    auto *PreludeBB = BasicBlock::Create(Context, "prelude", F, SaveBB);
+
+    // Read TPIDR2_EL0 in PreludeBB & branch to SaveBB if not 0.
+    Builder.SetInsertPoint(PreludeBB);
+    Function *TPIDR2Intr =
+        Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_get_tpidr2);
+    auto *TPIDR2 = Builder.CreateCall(TPIDR2Intr->getFunctionType(), TPIDR2Intr,
+                                      {}, "tpidr2");
+    auto *Cmp = Builder.CreateCmp(ICmpInst::ICMP_NE, TPIDR2,
+                                  Builder.getInt64(0), "cmp");
+    Builder.CreateCondBr(Cmp, SaveBB, OrigBB);
+
+    // Create a call __arm_tpidr2_save, which commits the lazy save.
+    Builder.SetInsertPoint(&SaveBB->back());
+    emitTPIDR2Save(M, Builder);
+
+    // Enable pstate.za at the start of the function.
+    Builder.SetInsertPoint(&OrigBB->front());
+    Function *EnableZAIntr =
+        Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_za_enable);
+    Builder.CreateCall(EnableZAIntr->getFunctionType(), EnableZAIntr);
+  }
+
+  if (FnAttrs.hasNewZABody()) {
+    // ZA state must be zeroed upon entry to a function with NewZA
+    Function *ZeroIntr =
+        Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_zero);
+    Builder.CreateCall(ZeroIntr->getFunctionType(), ZeroIntr,
+                       Builder.getInt32(0xff));
+  }
+
+  if (FnAttrs.isNewZT0()) {
+    Function *ClearZT0Intr =
+        Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_zero_zt);
+    Builder.CreateCall(ClearZT0Intr->getFunctionType(), ClearZT0Intr,
+                       {Builder.getInt32(0)});
+  }
+
+  if (FnAttrs.hasPrivateZAInterface()) {
+    // Before returning, disable pstate.za
+    for (BasicBlock &BB : *F) {
+      Instruction *T = BB.getTerminator();
+      if (!T || !isa<ReturnInst>(T))
+        continue;
+      Builder.SetInsertPoint(T);
+      Function *DisableZAIntr =
+          Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_za_disable);
+      Builder.CreateCall(DisableZAIntr->getFunctionType(), DisableZAIntr);
+    }
   }
 
   F->addFnAttr("aarch64_expanded_pstate_za");
@@ -142,8 +166,8 @@ bool SMEABI::runOnFunction(Function &F) {
 
   bool Changed = false;
   SMEAttrs FnAttrs(F);
-  if (FnAttrs.hasNewZABody())
-    Changed |= updateNewZAFunctions(M, &F, Builder);
+  if (FnAttrs.hasNewZABody() || FnAttrs.isNewZT0())
+    Changed |= updateNewStateFunctions(M, &F, Builder, FnAttrs);
 
   return Changed;
 }
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index 9693b6a664be26..c47ce42dcbd287 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -49,11 +49,13 @@ SMEAttrs::SMEAttrs(const CallBase &CB) {
 
 SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
   if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
-    Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Preserved |
-                SMEAttrs::ZA_NoLazySave);
+    Bitmask |=
+        (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Preserved |
+         SMEAttrs::ZA_NoLazySave | encodeZT0State(StateValue::Preserved));
   if (FuncName == "__arm_tpidr2_restore")
-    Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Shared |
-                SMEAttrs::ZA_NoLazySave);
+    Bitmask |=
+        (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Shared |
+         SMEAttrs::ZA_NoLazySave | encodeZT0State(StateValue::Preserved));
 }
 
 SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
index 88eaf19ec488f3..b93e865772eb9a 100644
--- a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -153,3 +153,110 @@ define void @zt0_in_caller_zt0_new_callee() "aarch64_in_zt0" nounwind {
   call void @callee() "aarch64_new_zt0";
   ret void;
 }
+
+;
+; New-ZA Caller
+;
+
+; Expect commit of lazy-save if ZA is dormant
+; Expect smstart ZA & clear ZT0
+; Before return, expect smstop ZA
+define void @zt0_new_caller() "aarch64_new_zt0" nounwind {
+; CHECK-LABEL: zt0_new_caller:
+; CHECK:       // %bb.0: // %prelude
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbz x8, .LBB6_2
+; CHECK-NEXT:  // %bb.1: // %save.za
+; CHECK-NEXT:    bl __arm_tpidr2_save
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:  .LBB6_2:
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    zero { zt0 }
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    smstop za
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @callee() "aarch64_in_zt0";
+  ret void;
+}
+
+; Expect commit of lazy-save if ZA is dormant
+; Expect smstart ZA, clear ZA & clear ZT0
+; Before return, expect smstop ZA
+define void @new_za_zt0_caller() "aarch64_pstate_za_new" "aarch64_new_zt0" nounwind {
+; CHECK-LABEL: new_za_zt0_caller:
+; CHECK:       // %bb.0: // %prelude
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbz x8, .LBB7_2
+; CHECK-NEXT:  // %bb.1: // %save.za
+; CHECK-NEXT:    bl __arm_tpidr2_save
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:  .LBB7_2:
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    zero {za}
+; CHECK-NEXT:    zero { zt0 }
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    smstop za
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
+  ret void;
+}
+
+; Expect clear ZA on entry
+define void @new_za_shared_zt0_caller() "aarch64_pstate_za_new" "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: new_za_shared_zt0_caller:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    zero {za}
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
+  ret void;
+}
+
+; Expect clear ZT0 on entry
+define void @shared_za_new_zt0() "aarch64_pstate_za_shared" "aarch64_new_zt0" nounwind {
+; CHECK-LABEL: shared_za_new_zt0:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    zero { zt0 }
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
+  ret void;
+}

bool SMEABI::updateNewZAFunctions(Module *M, Function *F,
IRBuilder<> &Builder) {
/// This function generates code at the beginning and end of a function marked
/// with either `aarch64_pstate_za_new` or `arm_new_zt0`.
Collaborator:

Suggested change:
- /// with either `aarch64_pstate_za_new` or `arm_new_zt0`.
+ /// with either `aarch64_pstate_za_new` or `aarch64_new_zt0`.

/// This function generates code at the beginning and end of a function marked
/// with either `aarch64_pstate_za_new` or `arm_new_zt0`.
/// At the beginning of the function, the following code is generated:
/// - Commit lazy-save if active [Private-ZA Interface]
Collaborator:

Can you add some rationale why we need to do this for ZT0-new functions?

My understanding of the rationale for a function with only __arm_new("zt0") is as follows:

  • PSTATE.ZA may be 1 on entry to the function, which would indicate the lazy-save mechanism is active.
  • ZA is untouched by this function, so there is no need to commit the lazy-save here.
  • However, when this function calls other functions that don't share ZT0, it would always require a conditional smstop za before the call to honour the ABI (either no lazy-save is active and PSTATE.ZA must be 0, or the lazy-save is active and PSTATE.ZA is 1).
  • It is therefore easier to just commit the lazy-save once at the start of the function (see the sketch below).
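
As a minimal sketch of the scenario described in these bullets (the function and callee names are hypothetical; the attribute spelling follows the tests in this patch):

; Hypothetical caller with only new ZT0 state and a callee with a
; private-ZA interface (shares neither ZA nor ZT0).
define void @zt0_new_only_caller() "aarch64_new_zt0" nounwind {
  ; Without committing the lazy save in the prelude, PSTATE.ZA here could be
  ; either 0, or 1 with a lazy save pending, so every call like the one below
  ; would need a conditional smstop za to honour the ABI. Committing the lazy
  ; save once on entry makes the state known unconditionally.
  call void @private_za_callee()
  ret void
}

declare void @private_za_callee()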

Contributor Author:

I've added a comment similar to this explaining why we still commit the lazy-save for ZT0 new functions

}

if (FnAttrs.hasNewZABody()) {
// ZA state must be zeroed upon entry to a function with NewZA
Collaborator:

nit: this comment is redundant.

Contributor Author:

Removed

SMEAttrs::ZA_NoLazySave);
Bitmask |=
(SMEAttrs::SM_Compatible | SMEAttrs::ZA_Preserved |
SMEAttrs::ZA_NoLazySave | encodeZT0State(StateValue::Preserved));
Collaborator:

__arm_preserves("zt0") implies that __arm_tpidr2_save and __arm_sme_state are Shared-ZA functions, which they are not. I guess you're trying to avoid generating the spill/fill of ZT0 around the call to these functions, but perhaps that should be made part of ZA_NoLazySave (which then probably needs a rename).

Contributor Author:

I've removed the Preserved attribute from these functions for now, which means we generate a spill/fill of ZT0 around these calls from an aarch64_new_zt0 function, and also call smstop za/smstart za when the caller does not also have ZA state.

- Removed ZT0 Preserved state from save & restore functions
; CHECK-NEXT: mrs x8, TPIDR2_EL0
; CHECK-NEXT: cbz x8, .LBB6_2
; CHECK-NEXT: // %bb.1: // %save.za
; CHECK-NEXT: mov x8, sp
; CHECK-NEXT: str zt0, [x8]
; CHECK-NEXT: smstop za
Collaborator:

I don't think it's valid to smstop za before calling __arm_tpidr2_save if there is an active lazy-save buffer set up. This leads to a state that is not defined in the ABI.

- Return false from requiresDisablingZABeforeCall if SME_ABI_Routine is set
; CHECK-NEXT: // %bb.1: // %save.za
; CHECK-NEXT: mov x8, sp
; CHECK-NEXT: str zt0, [x8]
; CHECK-NEXT: bl __arm_tpidr2_save
Collaborator:

Strictly speaking, we could remove the spill/fill of ZT0 here, because ZT0 is overwritten below by the zero { zt0 }, but I guess that's an optimisation for a future patch.

kmclaughlin-arm merged commit d4d81ac into llvm:main on Jan 22, 2024
3 of 4 checks passed