Skip to content

Commit

Permalink
[AArch64][SME] Add support for arm_locally_streaming functions.
Browse files Browse the repository at this point in the history
For functions with the `aarch64_pstate_sm_body` attribute, the compiler emits
an SMSTART at the start of the function and an SMSTOP at the end of the
function, such that all operations in the body use the right value for vscale.

Because the placement of these nodes is critically important (i.e. no
vscale-dependent operations should be done before SMSTART has been issued),
we require glueing the CopyFromReg to the Entry node such that we can
insert the SMSTART as part of that glued chain.

More details about the SME attributes and design can be found
in D131562.

Reviewed By: aemerson

Differential Revision: https://reviews.llvm.org/D131582
  • Loading branch information
sdesmalen-arm committed Oct 14, 2022
1 parent b3f1d58 commit 02df03c
Show file tree
Hide file tree
Showing 14 changed files with 378 additions and 37 deletions.
12 changes: 6 additions & 6 deletions llvm/include/llvm/CodeGen/SelectionDAG.h
Expand Up @@ -237,6 +237,12 @@ class SelectionDAG {
ProfileSummaryInfo *PSI = nullptr;
BlockFrequencyInfo *BFI = nullptr;

/// List of non-single value types.
FoldingSet<SDVTListNode> VTListMap;

/// Pool allocation for misc. objects that are created once per SelectionDAG.
BumpPtrAllocator Allocator;

/// The starting token.
SDNode EntryNode;

Expand All @@ -263,9 +269,6 @@ class SelectionDAG {
BumpPtrAllocator OperandAllocator;
ArrayRecycler<SDUse> OperandRecycler;

/// Pool allocation for misc. objects that are created once per SelectionDAG.
BumpPtrAllocator Allocator;

/// Tracks dbg_value and dbg_label information through SDISel.
SDDbgInfo *DbgInfo;

Expand Down Expand Up @@ -2281,9 +2284,6 @@ class SelectionDAG {
SDNode *FindNodeOrInsertPos(const FoldingSetNodeID &ID, const SDLoc &DL,
void *&InsertPos);

/// List of non-single value types.
FoldingSet<SDVTListNode> VTListMap;

/// Maps to auto-CSE operations.
std::vector<CondCodeSDNode*> CondCodeNodes;

Expand Down
1 change: 0 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
Expand Up @@ -1162,7 +1162,6 @@ EmitSpecialNode(SDNode *Node, bool IsClone, bool IsCloned,
#endif
llvm_unreachable("This target-independent node should have been selected!");
case ISD::EntryToken:
llvm_unreachable("EntryToken should have been excluded from the schedule!");
case ISD::MERGE_VALUES:
case ISD::TokenFactor: // fall thru
break;
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Expand Up @@ -1275,7 +1275,7 @@ Align SelectionDAG::getEVTAlign(EVT VT) const {
// EntryNode could meaningfully have debug info if we can find it...
SelectionDAG::SelectionDAG(const TargetMachine &tm, CodeGenOpt::Level OL)
: TM(tm), OptLevel(OL),
EntryNode(ISD::EntryToken, 0, DebugLoc(), getVTList(MVT::Other)),
EntryNode(ISD::EntryToken, 0, DebugLoc(), getVTList(MVT::Other, MVT::Glue)),
Root(getEntryNode()) {
InsertNode(&EntryNode);
DbgInfo = new SDDbgInfo();
Expand Down
59 changes: 57 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Expand Up @@ -6037,6 +6037,13 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
(void)Res;
}

SMEAttrs Attrs(MF.getFunction());
bool IsLocallyStreaming =
!Attrs.hasStreamingInterface() && Attrs.hasStreamingBody();
assert(Chain.getOpcode() == ISD::EntryToken && "Unexpected Chain value");
SDValue Glue = Chain.getValue(1);

SmallVector<SDValue, 16> ArgValues;
unsigned ExtraArgLocs = 0;
for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
CCValAssign &VA = ArgLocs[i - ExtraArgLocs];
Expand Down Expand Up @@ -6091,7 +6098,22 @@ SDValue AArch64TargetLowering::LowerFormalArguments(

// Transform the arguments in physical registers into virtual ones.
Register Reg = MF.addLiveIn(VA.getLocReg(), RC);
ArgValue = DAG.getCopyFromReg(Chain, DL, Reg, RegVT);

if (IsLocallyStreaming) {
// Locally streaming functions must insert the SMSTART in the correct
// position, so we use Glue to ensure no instructions can be scheduled
// between the chain of:
// t0: ch,glue = EntryNode
// t1: res,ch,glue = CopyFromReg
// ...
// tn: res,ch,glue = CopyFromReg t(n-1), ..
// t(n+1): ch, glue = SMSTART t0:0, ...., tn:2
// ^^^^^^
// This will be the new Chain/Root node.
ArgValue = DAG.getCopyFromReg(Chain, DL, Reg, RegVT, Glue);
Glue = ArgValue.getValue(2);
} else
ArgValue = DAG.getCopyFromReg(Chain, DL, Reg, RegVT);

// If this is an 8, 16 or 32-bit value, it is really passed promoted
// to 64 bits. Insert an assert[sz]ext to capture this, then
Expand Down Expand Up @@ -6245,6 +6267,27 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
}
assert((ArgLocs.size() + ExtraArgLocs) == Ins.size());

// Insert the SMSTART if this is a locally streaming function and
// make sure it is Glued to the last CopyFromReg value.
if (IsLocallyStreaming) {
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
Chain = DAG.getNode(
AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue),
{DAG.getRoot(),
DAG.getTargetConstant((int32_t)AArch64SVCR::SVCRSM, DL, MVT::i32),
DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64),
DAG.getRegisterMask(TRI->getSMStartStopCallPreservedMask()), Glue});
// Ensure that the SMSTART happens after the CopyWithChain such that its
// chain result is used.
for (unsigned I=0; I<InVals.size(); ++I) {
Register Reg = MF.getRegInfo().createVirtualRegister(
getRegClassFor(InVals[I].getValueType().getSimpleVT()));
SDValue X = DAG.getCopyToReg(Chain, DL, Reg, InVals[I]);
InVals[I] = DAG.getCopyFromReg(X, DL, Reg,
InVals[I].getValueType());
}
}

// varargs
if (isVarArg) {
if (!Subtarget->isTargetDarwin() || IsWin64) {
Expand Down Expand Up @@ -7485,6 +7528,19 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
}
}

const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();

// Emit SMSTOP before returning from a locally streaming function
SMEAttrs FuncAttrs(MF.getFunction());
if (FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface()) {
Chain = DAG.getNode(
AArch64ISD::SMSTOP, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
DAG.getTargetConstant((int32_t)AArch64SVCR::SVCRSM, DL, MVT::i32),
DAG.getConstant(1, DL, MVT::i64), DAG.getConstant(0, DL, MVT::i64),
DAG.getRegisterMask(TRI->getSMStartStopCallPreservedMask()));
Flag = Chain.getValue(1);
}

SmallVector<SDValue, 4> RetOps(1, Chain);
for (auto &RetVal : RetVals) {
Chain = DAG.getCopyToReg(Chain, DL, RetVal.first, RetVal.second, Flag);
Expand All @@ -7509,7 +7565,6 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
DAG.getRegister(RetValReg, getPointerTy(DAG.getDataLayout())));
}

const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
const MCPhysReg *I = TRI->getCalleeSavedRegsViaCopy(&MF);
if (I) {
for (; *I; ++I) {
Expand Down
24 changes: 18 additions & 6 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
Expand Up @@ -4256,6 +4256,8 @@ static void emitFrameOffsetAdj(MachineBasicBlock &MBB,
break;
case AArch64::ADDVL_XXI:
case AArch64::ADDPL_XXI:
case AArch64::ADDSVL_XXI:
case AArch64::ADDSPL_XXI:
MaxEncoding = 31;
ShiftSize = 0;
if (Offset < 0) {
Expand All @@ -4270,9 +4272,9 @@ static void emitFrameOffsetAdj(MachineBasicBlock &MBB,

// `Offset` can be in bytes or in "scalable bytes".
int VScale = 1;
if (Opc == AArch64::ADDVL_XXI)
if (Opc == AArch64::ADDVL_XXI || Opc == AArch64::ADDSVL_XXI)
VScale = 16;
else if (Opc == AArch64::ADDPL_XXI)
else if (Opc == AArch64::ADDPL_XXI || Opc == AArch64::ADDSPL_XXI)
VScale = 2;

// FIXME: If the offset won't fit in 24-bits, compute the offset into a
Expand Down Expand Up @@ -4369,6 +4371,14 @@ void llvm::emitFrameOffset(MachineBasicBlock &MBB,
bool NeedsWinCFI, bool *HasWinCFI,
bool EmitCFAOffset, StackOffset CFAOffset,
unsigned FrameReg) {
// If a function is marked as arm_locally_streaming, then the runtime value of
// vscale in the prologue/epilogue is different from the runtime value of vscale
// in the function's body. To avoid having to consider multiple vscales,
// we can use `addsvl` to allocate any scalable stack-slots, which under
// most circumstances will be only locals, not callee-save slots.
const Function &F = MBB.getParent()->getFunction();
bool UseSVL = F.hasFnAttribute("aarch64_pstate_sm_body");

int64_t Bytes, NumPredicateVectors, NumDataVectors;
AArch64InstrInfo::decomposeStackOffsetForFrameOffsets(
Offset, Bytes, NumPredicateVectors, NumDataVectors);
Expand Down Expand Up @@ -4399,17 +4409,19 @@ void llvm::emitFrameOffset(MachineBasicBlock &MBB,

if (NumDataVectors) {
emitFrameOffsetAdj(MBB, MBBI, DL, DestReg, SrcReg, NumDataVectors,
AArch64::ADDVL_XXI, TII, Flag, NeedsWinCFI, nullptr,
EmitCFAOffset, CFAOffset, FrameReg);
UseSVL ? AArch64::ADDSVL_XXI : AArch64::ADDVL_XXI,
TII, Flag, NeedsWinCFI, nullptr, EmitCFAOffset,
CFAOffset, FrameReg);
CFAOffset += StackOffset::getScalable(-NumDataVectors * 16);
SrcReg = DestReg;
}

if (NumPredicateVectors) {
assert(DestReg != AArch64::SP && "Unaligned access to SP");
emitFrameOffsetAdj(MBB, MBBI, DL, DestReg, SrcReg, NumPredicateVectors,
AArch64::ADDPL_XXI, TII, Flag, NeedsWinCFI, nullptr,
EmitCFAOffset, CFAOffset, FrameReg);
UseSVL ? AArch64::ADDSPL_XXI : AArch64::ADDPL_XXI,
TII, Flag, NeedsWinCFI, nullptr, EmitCFAOffset,
CFAOffset, FrameReg);
}
}

Expand Down
10 changes: 10 additions & 0 deletions llvm/test/CodeGen/AArch64/sme-get-pstatesm.ll
Expand Up @@ -22,7 +22,17 @@ define i64 @get_pstatesm_streaming() nounwind "aarch64_pstate_sm_enabled" {
define i64 @get_pstatesm_locally_streaming() nounwind "aarch64_pstate_sm_body" {
; CHECK-LABEL: get_pstatesm_locally_streaming:
; CHECK: // %bb.0:
; CHECK-NEXT: stp d15, d14, [sp, #-64]! // 16-byte Folded Spill
; CHECK-NEXT: stp d13, d12, [sp, #16] // 16-byte Folded Spill
; CHECK-NEXT: stp d11, d10, [sp, #32] // 16-byte Folded Spill
; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
; CHECK-NEXT: smstart sm
; CHECK-NEXT: smstop sm
; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
; CHECK-NEXT: mov w0, #1
; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
; CHECK-NEXT: ldp d15, d14, [sp], #64 // 16-byte Folded Reload
; CHECK-NEXT: ret
%pstate = call i64 @llvm.aarch64.sme.get.pstatesm()
ret i64 %pstate
Expand Down

0 comments on commit 02df03c

Please sign in to comment.