-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AArch64] NFC: Simplify the smstart/smstop pseudo. #85067
[AArch64] NFC: Simplify the smstart/smstop pseudo. #85067
Conversation
This is just a bit of cleanup to make the pseudo/code easier to understand. This is based on the observation that we only need to pass in a runtime value for 'pstate' if is actually needed for generating a runtime check.
@llvm/pr-subscribers-backend-aarch64 Author: Sander de Smalen (sdesmalen-arm) ChangesThis is just a bit of cleanup to make the pseudo/code easier to understand. This is based on the observation that we only need to pass in a runtime value for 'pstate' if is actually needed for generating a runtime check. Full diff: https://github.com/llvm/llvm-project/pull/85067.diff 5 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
index b2c52b443753dc..3afd48f7fb299c 100644
--- a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
@@ -987,7 +987,7 @@ AArch64ExpandPseudo::expandCondSMToggle(MachineBasicBlock &MBB,
// Expand the pseudo into smstart or smstop instruction. The pseudo has the
// following operands:
//
- // MSRpstatePseudo <za|sm|both>, <0|1>, pstate.sm, expectedval, <regmask>
+ // MSRpstatePseudo <za|sm|both>, <0|1>, pstate.sm, condition, <regmask>
//
// The pseudo is expanded into a conditional smstart/smstop, with a
// check if pstate.sm (register) equals the expected value, and if not,
@@ -997,9 +997,9 @@ AArch64ExpandPseudo::expandCondSMToggle(MachineBasicBlock &MBB,
// streaming-compatible function:
//
// OrigBB:
- // MSRpstatePseudo 3, 0, %0, 0, <regmask> <- Conditional SMSTOP
+ // MSRpstatePseudo 3, 0, %0, IfCallerIsStreaming, <regmask> <- Cond SMSTOP
// bl @normal_callee
- // MSRpstatePseudo 3, 1, %0, 0, <regmask> <- Conditional SMSTART
+ // MSRpstatePseudo 3, 1, %0, IfCallerIsStreaming, <regmask> <- Cond SMSTART
//
// ...which will be transformed into:
//
@@ -1022,11 +1022,20 @@ AArch64ExpandPseudo::expandCondSMToggle(MachineBasicBlock &MBB,
// We test the live value of pstate.sm and toggle pstate.sm if this is not the
// expected value for the callee (0 for a normal callee and 1 for a streaming
// callee).
- auto PStateSM = MI.getOperand(2).getReg();
+ unsigned Opc;
+ switch (MI.getOperand(2).getImm()) {
+ case AArch64SME::Always:
+ llvm_unreachable("Should have matched to instruction directly");
+ case AArch64SME::IfCallerIsStreaming:
+ Opc = AArch64::TBNZW;
+ break;
+ case AArch64SME::IfCallerIsNonStreaming:
+ Opc = AArch64::TBZW;
+ break;
+ }
+ auto PStateSM = MI.getOperand(3).getReg();
auto TRI = MBB.getParent()->getSubtarget().getRegisterInfo();
unsigned SMReg32 = TRI->getSubReg(PStateSM, AArch64::sub_32);
- bool IsStreamingCallee = MI.getOperand(3).getImm();
- unsigned Opc = IsStreamingCallee ? AArch64::TBZW : AArch64::TBNZW;
MachineInstrBuilder Tbx =
BuildMI(MBB, MBBI, DL, TII->get(Opc)).addReg(SMReg32).addImm(0);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 054311d39e7b83..90c9f1fd11ff25 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -5270,13 +5270,13 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
AArch64ISD::SMSTART, DL, MVT::Other,
Op->getOperand(0), // Chain
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
- DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+ DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
case Intrinsic::aarch64_sme_za_disable:
return DAG.getNode(
AArch64ISD::SMSTOP, DL, MVT::Other,
Op->getOperand(0), // Chain
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
- DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+ DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
}
}
@@ -7197,11 +7197,11 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
getRegClassFor(PStateSM.getValueType().getSimpleVT()));
FuncInfo->setPStateSMReg(Reg);
Chain = DAG.getCopyToReg(Chain, DL, Reg, PStateSM);
- } else {
- PStateSM = DAG.getConstant(0, DL, MVT::i64);
- }
- Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue, PStateSM,
- /*Entry*/ true);
+ Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
+ AArch64SME::IfCallerIsNonStreaming, PStateSM);
+ } else
+ Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
+ AArch64SME::Always);
// Ensure that the SMSTART happens after the CopyWithChain such that its
// chain result is used.
@@ -7776,9 +7776,11 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
}
}
-SDValue AArch64TargetLowering::changeStreamingMode(
- SelectionDAG &DAG, SDLoc DL, bool Enable,
- SDValue Chain, SDValue InGlue, SDValue PStateSM, bool Entry) const {
+SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
+ bool Enable, SDValue Chain,
+ SDValue InGlue,
+ unsigned Condition,
+ SDValue PStateSM) const {
MachineFunction &MF = DAG.getMachineFunction();
AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
FuncInfo->setHasStreamingModeChanges(true);
@@ -7787,10 +7789,13 @@ SDValue AArch64TargetLowering::changeStreamingMode(
SDValue RegMask = DAG.getRegisterMask(TRI->getSMStartStopCallPreservedMask());
SDValue MSROp =
DAG.getTargetConstant((int32_t)AArch64SVCR::SVCRSM, DL, MVT::i32);
-
- SDValue ExpectedSMVal =
- DAG.getTargetConstant(Entry ? Enable : !Enable, DL, MVT::i64);
- SmallVector<SDValue> Ops = {Chain, MSROp, PStateSM, ExpectedSMVal, RegMask};
+ SDValue ConditionOp = DAG.getTargetConstant(Condition, DL, MVT::i64);
+ SmallVector<SDValue> Ops = {Chain, MSROp, ConditionOp};
+ if (Condition != AArch64SME::Always) {
+ assert(PStateSM && "PStateSM should be defined");
+ Ops.push_back(PStateSM);
+ }
+ Ops.push_back(RegMask);
if (InGlue)
Ops.push_back(InGlue);
@@ -7799,6 +7804,19 @@ SDValue AArch64TargetLowering::changeStreamingMode(
return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
}
+static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
+ const SMEAttrs &CalleeAttrs) {
+ if (!CallerAttrs.hasStreamingCompatibleInterface() ||
+ CallerAttrs.hasStreamingBody())
+ return AArch64SME::Always;
+ if (CalleeAttrs.hasNonStreamingInterface())
+ return AArch64SME::IfCallerIsStreaming;
+ if (CalleeAttrs.hasStreamingInterface())
+ return AArch64SME::IfCallerIsNonStreaming;
+
+ llvm_unreachable("Unsupported attributes");
+}
+
/// LowerCall - Lower a call to a callseq_start + CALL + callseq_end chain,
/// and add input and output parameter nodes.
SDValue
@@ -8018,7 +8036,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
Chain = DAG.getNode(
AArch64ISD::SMSTOP, DL, MVT::Other, Chain,
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
- DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+ DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
// Adjust the stack pointer for the new arguments...
// These operations are automatically eliminated by the prolog/epilog pass
@@ -8289,9 +8307,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
SDValue InGlue;
if (RequiresSMChange) {
- SDValue NewChain =
- changeStreamingMode(DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain,
- InGlue, PStateSM, true);
+ SDValue NewChain = changeStreamingMode(
+ DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain, InGlue,
+ getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
Chain = NewChain.getValue(0);
InGlue = NewChain.getValue(1);
}
@@ -8445,8 +8463,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
if (RequiresSMChange) {
assert(PStateSM && "Expected a PStateSM to be set");
- Result = changeStreamingMode(DAG, DL, !CalleeAttrs.hasStreamingInterface(),
- Result, InGlue, PStateSM, false);
+ Result = changeStreamingMode(
+ DAG, DL, !CalleeAttrs.hasStreamingInterface(), Result, InGlue,
+ getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
}
if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))
@@ -8454,7 +8473,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
Result = DAG.getNode(
AArch64ISD::SMSTART, DL, MVT::Other, Result,
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
- DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+ DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
if (ShouldPreserveZT0)
Result =
@@ -8589,13 +8608,12 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
Register Reg = FuncInfo->getPStateSMReg();
assert(Reg.isValid() && "PStateSM Register is invalid");
SDValue PStateSM = DAG.getCopyFromReg(Chain, DL, Reg, MVT::i64);
- Chain =
- changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
- /*Glue*/ SDValue(), PStateSM, /*Entry*/ false);
+ Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
+ /*Glue*/ SDValue(),
+ AArch64SME::IfCallerIsNonStreaming, PStateSM);
} else
- Chain = changeStreamingMode(
- DAG, DL, /*Enable*/ false, Chain,
- /*Glue*/ SDValue(), DAG.getConstant(1, DL, MVT::i64), /*Entry*/ true);
+ Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
+ /*Glue*/ SDValue(), AArch64SME::Always);
Glue = Chain.getValue(1);
}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 68341c199e0a2a..89016cbf56e39e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -968,12 +968,12 @@ class AArch64TargetLowering : public TargetLowering {
bool shouldExpandCttzElements(EVT VT) const override;
/// If a change in streaming mode is required on entry to/return from a
- /// function call it emits and returns the corresponding SMSTART or SMSTOP node.
- /// \p Entry tells whether this is before/after the Call, which is necessary
- /// because PSTATE.SM is only queried once.
+ /// function call it emits and returns the corresponding SMSTART or SMSTOP
+ /// node. \p Condition should be one of the enum values from
+ /// AArch64SME::ToggleCondition.
SDValue changeStreamingMode(SelectionDAG &DAG, SDLoc DL, bool Enable,
- SDValue Chain, SDValue InGlue,
- SDValue PStateSM, bool Entry) const;
+ SDValue Chain, SDValue InGlue, unsigned Condition,
+ SDValue PStateSM = SDValue()) const;
bool isVScaleKnownToBeAPowerOfTwo() const override { return true; }
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 2907ba74ff8108..1554f1c92b5bbb 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -10,12 +10,12 @@
//
//===----------------------------------------------------------------------===//
-def AArch64_smstart : SDNode<"AArch64ISD::SMSTART", SDTypeProfile<0, 3,
- [SDTCisInt<0>, SDTCisInt<0>, SDTCisInt<0>]>,
+def AArch64_smstart : SDNode<"AArch64ISD::SMSTART", SDTypeProfile<0, 2,
+ [SDTCisInt<0>, SDTCisInt<0>]>,
[SDNPHasChain, SDNPSideEffect, SDNPVariadic,
SDNPOptInGlue, SDNPOutGlue]>;
-def AArch64_smstop : SDNode<"AArch64ISD::SMSTOP", SDTypeProfile<0, 3,
- [SDTCisInt<0>, SDTCisInt<0>, SDTCisInt<0>]>,
+def AArch64_smstop : SDNode<"AArch64ISD::SMSTOP", SDTypeProfile<0, 2,
+ [SDTCisInt<0>, SDTCisInt<0>]>,
[SDNPHasChain, SDNPSideEffect, SDNPVariadic,
SDNPOptInGlue, SDNPOutGlue]>;
def AArch64_restore_za : SDNode<"AArch64ISD::RESTORE_ZA", SDTypeProfile<0, 3,
@@ -158,34 +158,6 @@ def : Pat<(AArch64_restore_za
(i64 GPR64:$tpidr2_el0), (i64 GPR64sp:$tpidr2obj), (i64 texternalsym:$restore_routine)),
(RestoreZAPseudo GPR64:$tpidr2_el0, GPR64sp:$tpidr2obj, texternalsym:$restore_routine)>;
-// Scenario A:
-//
-// %pstate.before.call = 1
-// if (%pstate.before.call != 0)
-// smstop (pstate_za|pstate_sm)
-// call fn()
-// if (%pstate.before.call != 0)
-// smstart (pstate_za|pstate_sm)
-//
-def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 1), (i64 0)), // before call
- (MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>;
-def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 1), (i64 0)), // after call
- (MSRpstatesvcrImm1 svcr_op:$pstate, 0b1)>;
-
-// Scenario B:
-//
-// %pstate.before.call = 0
-// if (%pstate.before.call != 1)
-// smstart (pstate_za|pstate_sm)
-// call fn()
-// if (%pstate.before.call != 1)
-// smstop (pstate_za|pstate_sm)
-//
-def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 0), (i64 1)), // before call
- (MSRpstatesvcrImm1 svcr_op:$pstate, 0b1)>;
-def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 0), (i64 1)), // after call
- (MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>;
-
// Read and write TPIDR2_EL0
def : Pat<(int_aarch64_sme_set_tpidr2 i64:$val),
(MSR 0xde85, GPR64:$val)>;
@@ -230,17 +202,24 @@ defm COALESCER_BARRIER : CoalescerBarriers;
// SME instructions.
def MSRpstatePseudo :
Pseudo<(outs),
- (ins svcr_op:$pstatefield, timm0_1:$imm, GPR64:$rtpstate, timm0_1:$expected_pstate, variable_ops), []>,
+ (ins svcr_op:$pstatefield, timm0_1:$imm, timm0_31:$condition, variable_ops), []>,
Sched<[WriteSys]> {
let hasPostISelHook = 1;
let Uses = [VG];
let Defs = [VG];
}
-def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 GPR64:$rtpstate), (i64 timm0_1:$expected_pstate)),
- (MSRpstatePseudo svcr_op:$pstate, 0b1, GPR64:$rtpstate, timm0_1:$expected_pstate)>;
-def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 GPR64:$rtpstate), (i64 timm0_1:$expected_pstate)),
- (MSRpstatePseudo svcr_op:$pstate, 0b0, GPR64:$rtpstate, timm0_1:$expected_pstate)>;
+def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 timm0_31:$condition)),
+ (MSRpstatePseudo svcr_op:$pstate, 0b1, timm0_31:$condition)>;
+def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 timm0_31:$condition)),
+ (MSRpstatePseudo svcr_op:$pstate, 0b0, timm0_31:$condition)>;
+
+// Unconditional start/stop
+def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 /*AArch64SME::Always*/0)),
+ (MSRpstatesvcrImm1 svcr_op:$pstate, 0b1)>;
+def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 /*AArch64SME::Always*/0)),
+ (MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>;
+
//===----------------------------------------------------------------------===//
// SME2 Instructions
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
index ed8336a2e8ad34..f821bb527aedb8 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
@@ -591,6 +591,14 @@ namespace AArch64BTIHint {
#include "AArch64GenSystemOperands.inc"
}
+namespace AArch64SME {
+enum ToggleCondition : unsigned {
+ Always,
+ IfCallerIsStreaming,
+ IfCallerIsNonStreaming
+};
+}
+
namespace AArch64SE {
enum ShiftExtSpecifiers {
Invalid = -1,
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems like a nice improvement to the pseudo, thanks @sdesmalen-arm!
This is just a bit of cleanup to make the pseudo/code easier to understand. This is based on the observation that we only need to pass in a runtime value for 'pstate' if is actually needed for generating a runtime check.