Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AArch64] NFC: Simplify the smstart/smstop pseudo. #85067

Merged
merged 2 commits into from
Mar 15, 2024

Conversation

sdesmalen-arm
Copy link
Collaborator

This is just a bit of cleanup to make the pseudo/code easier to understand. This is based on the observation that we only need to pass in a runtime value for 'pstate' if is actually needed for generating a runtime check.

This is just a bit of cleanup to make the pseudo/code easier to understand.
This is based on the observation that we only need to pass in a runtime
value for 'pstate' if is actually needed for generating a runtime check.
@llvmbot
Copy link
Collaborator

llvmbot commented Mar 13, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Sander de Smalen (sdesmalen-arm)

Changes

This is just a bit of cleanup to make the pseudo/code easier to understand. This is based on the observation that we only need to pass in a runtime value for 'pstate' if is actually needed for generating a runtime check.


Full diff: https://github.com/llvm/llvm-project/pull/85067.diff

5 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp (+15-6)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+45-27)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+5-5)
  • (modified) llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td (+16-37)
  • (modified) llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h (+8)
diff --git a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
index b2c52b443753dc..3afd48f7fb299c 100644
--- a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
@@ -987,7 +987,7 @@ AArch64ExpandPseudo::expandCondSMToggle(MachineBasicBlock &MBB,
   // Expand the pseudo into smstart or smstop instruction. The pseudo has the
   // following operands:
   //
-  //   MSRpstatePseudo <za|sm|both>, <0|1>, pstate.sm, expectedval, <regmask>
+  //   MSRpstatePseudo <za|sm|both>, <0|1>, pstate.sm, condition, <regmask>
   //
   // The pseudo is expanded into a conditional smstart/smstop, with a
   // check if pstate.sm (register) equals the expected value, and if not,
@@ -997,9 +997,9 @@ AArch64ExpandPseudo::expandCondSMToggle(MachineBasicBlock &MBB,
   // streaming-compatible function:
   //
   // OrigBB:
-  //   MSRpstatePseudo 3, 0, %0, 0, <regmask>             <- Conditional SMSTOP
+  //   MSRpstatePseudo 3, 0, %0, IfCallerIsStreaming, <regmask>  <- Cond SMSTOP
   //   bl @normal_callee
-  //   MSRpstatePseudo 3, 1, %0, 0, <regmask>             <- Conditional SMSTART
+  //   MSRpstatePseudo 3, 1, %0, IfCallerIsStreaming, <regmask>  <- Cond SMSTART
   //
   // ...which will be transformed into:
   //
@@ -1022,11 +1022,20 @@ AArch64ExpandPseudo::expandCondSMToggle(MachineBasicBlock &MBB,
   // We test the live value of pstate.sm and toggle pstate.sm if this is not the
   // expected value for the callee (0 for a normal callee and 1 for a streaming
   // callee).
-  auto PStateSM = MI.getOperand(2).getReg();
+  unsigned Opc;
+  switch (MI.getOperand(2).getImm()) {
+  case AArch64SME::Always:
+    llvm_unreachable("Should have matched to instruction directly");
+  case AArch64SME::IfCallerIsStreaming:
+    Opc = AArch64::TBNZW;
+    break;
+  case AArch64SME::IfCallerIsNonStreaming:
+    Opc = AArch64::TBZW;
+    break;
+  }
+  auto PStateSM = MI.getOperand(3).getReg();
   auto TRI = MBB.getParent()->getSubtarget().getRegisterInfo();
   unsigned SMReg32 = TRI->getSubReg(PStateSM, AArch64::sub_32);
-  bool IsStreamingCallee = MI.getOperand(3).getImm();
-  unsigned Opc = IsStreamingCallee ? AArch64::TBZW : AArch64::TBNZW;
   MachineInstrBuilder Tbx =
       BuildMI(MBB, MBBI, DL, TII->get(Opc)).addReg(SMReg32).addImm(0);
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 054311d39e7b83..90c9f1fd11ff25 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -5270,13 +5270,13 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
         AArch64ISD::SMSTART, DL, MVT::Other,
         Op->getOperand(0), // Chain
         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
-        DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+        DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
   case Intrinsic::aarch64_sme_za_disable:
     return DAG.getNode(
         AArch64ISD::SMSTOP, DL, MVT::Other,
         Op->getOperand(0), // Chain
         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
-        DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+        DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
   }
 }
 
@@ -7197,11 +7197,11 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
           getRegClassFor(PStateSM.getValueType().getSimpleVT()));
       FuncInfo->setPStateSMReg(Reg);
       Chain = DAG.getCopyToReg(Chain, DL, Reg, PStateSM);
-    } else {
-      PStateSM = DAG.getConstant(0, DL, MVT::i64);
-    }
-    Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue, PStateSM,
-                                /*Entry*/ true);
+      Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
+                                  AArch64SME::IfCallerIsNonStreaming, PStateSM);
+    } else
+      Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
+                                  AArch64SME::Always);
 
     // Ensure that the SMSTART happens after the CopyWithChain such that its
     // chain result is used.
@@ -7776,9 +7776,11 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
   }
 }
 
-SDValue AArch64TargetLowering::changeStreamingMode(
-    SelectionDAG &DAG, SDLoc DL, bool Enable,
-    SDValue Chain, SDValue InGlue, SDValue PStateSM, bool Entry) const {
+SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
+                                                   bool Enable, SDValue Chain,
+                                                   SDValue InGlue,
+                                                   unsigned Condition,
+                                                   SDValue PStateSM) const {
   MachineFunction &MF = DAG.getMachineFunction();
   AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
   FuncInfo->setHasStreamingModeChanges(true);
@@ -7787,10 +7789,13 @@ SDValue AArch64TargetLowering::changeStreamingMode(
   SDValue RegMask = DAG.getRegisterMask(TRI->getSMStartStopCallPreservedMask());
   SDValue MSROp =
       DAG.getTargetConstant((int32_t)AArch64SVCR::SVCRSM, DL, MVT::i32);
-
-  SDValue ExpectedSMVal =
-      DAG.getTargetConstant(Entry ? Enable : !Enable, DL, MVT::i64);
-  SmallVector<SDValue> Ops = {Chain, MSROp, PStateSM, ExpectedSMVal, RegMask};
+  SDValue ConditionOp = DAG.getTargetConstant(Condition, DL, MVT::i64);
+  SmallVector<SDValue> Ops = {Chain, MSROp, ConditionOp};
+  if (Condition != AArch64SME::Always) {
+    assert(PStateSM && "PStateSM should be defined");
+    Ops.push_back(PStateSM);
+  }
+  Ops.push_back(RegMask);
 
   if (InGlue)
     Ops.push_back(InGlue);
@@ -7799,6 +7804,19 @@ SDValue AArch64TargetLowering::changeStreamingMode(
   return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
 }
 
+static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
+                               const SMEAttrs &CalleeAttrs) {
+  if (!CallerAttrs.hasStreamingCompatibleInterface() ||
+      CallerAttrs.hasStreamingBody())
+    return AArch64SME::Always;
+  if (CalleeAttrs.hasNonStreamingInterface())
+    return AArch64SME::IfCallerIsStreaming;
+  if (CalleeAttrs.hasStreamingInterface())
+    return AArch64SME::IfCallerIsNonStreaming;
+
+  llvm_unreachable("Unsupported attributes");
+}
+
 /// LowerCall - Lower a call to a callseq_start + CALL + callseq_end chain,
 /// and add input and output parameter nodes.
 SDValue
@@ -8018,7 +8036,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     Chain = DAG.getNode(
         AArch64ISD::SMSTOP, DL, MVT::Other, Chain,
         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
-        DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+        DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
 
   // Adjust the stack pointer for the new arguments...
   // These operations are automatically eliminated by the prolog/epilog pass
@@ -8289,9 +8307,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
 
   SDValue InGlue;
   if (RequiresSMChange) {
-    SDValue NewChain =
-        changeStreamingMode(DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain,
-                            InGlue, PStateSM, true);
+    SDValue NewChain = changeStreamingMode(
+        DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain, InGlue,
+        getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
     Chain = NewChain.getValue(0);
     InGlue = NewChain.getValue(1);
   }
@@ -8445,8 +8463,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
 
   if (RequiresSMChange) {
     assert(PStateSM && "Expected a PStateSM to be set");
-    Result = changeStreamingMode(DAG, DL, !CalleeAttrs.hasStreamingInterface(),
-                                 Result, InGlue, PStateSM, false);
+    Result = changeStreamingMode(
+        DAG, DL, !CalleeAttrs.hasStreamingInterface(), Result, InGlue,
+        getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
   }
 
   if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))
@@ -8454,7 +8473,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     Result = DAG.getNode(
         AArch64ISD::SMSTART, DL, MVT::Other, Result,
         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
-        DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+        DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
 
   if (ShouldPreserveZT0)
     Result =
@@ -8589,13 +8608,12 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
       Register Reg = FuncInfo->getPStateSMReg();
       assert(Reg.isValid() && "PStateSM Register is invalid");
       SDValue PStateSM = DAG.getCopyFromReg(Chain, DL, Reg, MVT::i64);
-      Chain =
-          changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
-                              /*Glue*/ SDValue(), PStateSM, /*Entry*/ false);
+      Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
+                                  /*Glue*/ SDValue(),
+                                  AArch64SME::IfCallerIsNonStreaming, PStateSM);
     } else
-      Chain = changeStreamingMode(
-          DAG, DL, /*Enable*/ false, Chain,
-          /*Glue*/ SDValue(), DAG.getConstant(1, DL, MVT::i64), /*Entry*/ true);
+      Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
+                                  /*Glue*/ SDValue(), AArch64SME::Always);
     Glue = Chain.getValue(1);
   }
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 68341c199e0a2a..89016cbf56e39e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -968,12 +968,12 @@ class AArch64TargetLowering : public TargetLowering {
   bool shouldExpandCttzElements(EVT VT) const override;
 
   /// If a change in streaming mode is required on entry to/return from a
-  /// function call it emits and returns the corresponding SMSTART or SMSTOP node.
-  /// \p Entry tells whether this is before/after the Call, which is necessary
-  /// because PSTATE.SM is only queried once.
+  /// function call it emits and returns the corresponding SMSTART or SMSTOP
+  /// node. \p Condition should be one of the enum values from
+  /// AArch64SME::ToggleCondition.
   SDValue changeStreamingMode(SelectionDAG &DAG, SDLoc DL, bool Enable,
-                              SDValue Chain, SDValue InGlue,
-                              SDValue PStateSM, bool Entry) const;
+                              SDValue Chain, SDValue InGlue, unsigned Condition,
+                              SDValue PStateSM = SDValue()) const;
 
   bool isVScaleKnownToBeAPowerOfTwo() const override { return true; }
 
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 2907ba74ff8108..1554f1c92b5bbb 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -10,12 +10,12 @@
 //
 //===----------------------------------------------------------------------===//
 
-def AArch64_smstart : SDNode<"AArch64ISD::SMSTART", SDTypeProfile<0, 3,
-                             [SDTCisInt<0>, SDTCisInt<0>, SDTCisInt<0>]>,
+def AArch64_smstart : SDNode<"AArch64ISD::SMSTART", SDTypeProfile<0, 2,
+                             [SDTCisInt<0>, SDTCisInt<0>]>,
                              [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
                               SDNPOptInGlue, SDNPOutGlue]>;
-def AArch64_smstop  : SDNode<"AArch64ISD::SMSTOP", SDTypeProfile<0, 3,
-                             [SDTCisInt<0>, SDTCisInt<0>, SDTCisInt<0>]>,
+def AArch64_smstop  : SDNode<"AArch64ISD::SMSTOP", SDTypeProfile<0, 2,
+                             [SDTCisInt<0>, SDTCisInt<0>]>,
                              [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
                               SDNPOptInGlue, SDNPOutGlue]>;
 def AArch64_restore_za : SDNode<"AArch64ISD::RESTORE_ZA", SDTypeProfile<0, 3,
@@ -158,34 +158,6 @@ def : Pat<(AArch64_restore_za
             (i64 GPR64:$tpidr2_el0), (i64 GPR64sp:$tpidr2obj), (i64 texternalsym:$restore_routine)),
           (RestoreZAPseudo GPR64:$tpidr2_el0, GPR64sp:$tpidr2obj, texternalsym:$restore_routine)>;
 
-// Scenario A:
-//
-//   %pstate.before.call = 1
-//   if (%pstate.before.call != 0)
-//     smstop (pstate_za|pstate_sm)
-//   call fn()
-//   if (%pstate.before.call != 0)
-//     smstart (pstate_za|pstate_sm)
-//
-def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 1), (i64 0)),   // before call
-          (MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>;
-def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 1), (i64 0)),  // after call
-          (MSRpstatesvcrImm1 svcr_op:$pstate, 0b1)>;
-
-// Scenario B:
-//
-//   %pstate.before.call = 0
-//   if (%pstate.before.call != 1)
-//     smstart (pstate_za|pstate_sm)
-//   call fn()
-//   if (%pstate.before.call != 1)
-//     smstop (pstate_za|pstate_sm)
-//
-def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 0), (i64 1)),  // before call
-          (MSRpstatesvcrImm1 svcr_op:$pstate, 0b1)>;
-def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 0), (i64 1)),   // after call
-          (MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>;
-
 // Read and write TPIDR2_EL0
 def : Pat<(int_aarch64_sme_set_tpidr2 i64:$val),
           (MSR 0xde85, GPR64:$val)>;
@@ -230,17 +202,24 @@ defm COALESCER_BARRIER : CoalescerBarriers;
 // SME instructions.
 def MSRpstatePseudo :
   Pseudo<(outs),
-           (ins svcr_op:$pstatefield, timm0_1:$imm, GPR64:$rtpstate, timm0_1:$expected_pstate, variable_ops), []>,
+           (ins svcr_op:$pstatefield, timm0_1:$imm, timm0_31:$condition, variable_ops), []>,
     Sched<[WriteSys]> {
   let hasPostISelHook = 1;
   let Uses = [VG];
   let Defs = [VG];
 }
 
-def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 GPR64:$rtpstate), (i64 timm0_1:$expected_pstate)),
-          (MSRpstatePseudo svcr_op:$pstate, 0b1, GPR64:$rtpstate, timm0_1:$expected_pstate)>;
-def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 GPR64:$rtpstate), (i64 timm0_1:$expected_pstate)),
-          (MSRpstatePseudo svcr_op:$pstate, 0b0, GPR64:$rtpstate, timm0_1:$expected_pstate)>;
+def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 timm0_31:$condition)),
+          (MSRpstatePseudo svcr_op:$pstate, 0b1, timm0_31:$condition)>;
+def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 timm0_31:$condition)),
+          (MSRpstatePseudo svcr_op:$pstate, 0b0, timm0_31:$condition)>;
+
+// Unconditional start/stop
+def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 /*AArch64SME::Always*/0)),
+          (MSRpstatesvcrImm1 svcr_op:$pstate, 0b1)>;
+def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 /*AArch64SME::Always*/0)),
+          (MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>;
+
 
 //===----------------------------------------------------------------------===//
 // SME2 Instructions
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
index ed8336a2e8ad34..f821bb527aedb8 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
@@ -591,6 +591,14 @@ namespace AArch64BTIHint {
   #include "AArch64GenSystemOperands.inc"
 }
 
+namespace AArch64SME {
+enum ToggleCondition : unsigned {
+  Always,
+  IfCallerIsStreaming,
+  IfCallerIsNonStreaming
+};
+}
+
 namespace AArch64SE {
     enum ShiftExtSpecifiers {
         Invalid = -1,

Copy link
Contributor

@kmclaughlin-arm kmclaughlin-arm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a nice improvement to the pseudo, thanks @sdesmalen-arm!

@sdesmalen-arm sdesmalen-arm merged commit e639e7e into llvm:main Mar 15, 2024
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants