Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 99 additions & 57 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8094,13 +8094,76 @@ static SDValue getZT0FrameIndex(MachineFrameInfo &MFI,
DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
}

// Emit a call to __arm_sme_save or __arm_sme_restore.
static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
SelectionDAG &DAG,
AArch64FunctionInfo *Info, SDLoc DL,
SDValue Chain, bool IsSave) {
MachineFunction &MF = DAG.getMachineFunction();
AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
FuncInfo->setSMESaveBufferUsed();
TargetLowering::ArgListTy Args;
Args.emplace_back(
DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64),
PointerType::getUnqual(*DAG.getContext()));

RTLIB::Libcall LC =
IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE;
SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC),
TLI.getPointerTy(DAG.getDataLayout()));
auto *RetTy = Type::getVoidTy(*DAG.getContext());
TargetLowering::CallLoweringInfo CLI(DAG);
CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
return TLI.LowerCallTo(CLI).second;
}

static SDValue emitRestoreZALazySave(SDValue Chain, SDLoc DL,
const AArch64TargetLowering &TLI,
const AArch64RegisterInfo &TRI,
AArch64FunctionInfo &FuncInfo,
SelectionDAG &DAG) {
// Conditionally restore the lazy save using a pseudo node.
RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE;
TPIDR2Object &TPIDR2 = FuncInfo.getTPIDR2Obj();
SDValue RegMask = DAG.getRegisterMask(TRI.getCallPreservedMask(
DAG.getMachineFunction(), TLI.getLibcallCallingConv(LC)));
SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
TLI.getLibcallName(LC), TLI.getPointerTy(DAG.getDataLayout()));
SDValue TPIDR2_EL0 = DAG.getNode(
ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Chain,
DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
// Copy the address of the TPIDR2 block into X0 before 'calling' the
// RESTORE_ZA pseudo.
SDValue Glue;
SDValue TPIDR2Block = DAG.getFrameIndex(
TPIDR2.FrameIndex,
DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
Chain = DAG.getCopyToReg(Chain, DL, AArch64::X0, TPIDR2Block, Glue);
Chain =
DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
{Chain, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
RestoreRoutine, RegMask, Chain.getValue(1)});
// Finally reset the TPIDR2_EL0 register to 0.
Chain = DAG.getNode(
ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
DAG.getConstant(0, DL, MVT::i64));
TPIDR2.Uses++;
return Chain;
}

SDValue AArch64TargetLowering::lowerEHPadEntry(SDValue Chain, SDLoc const &DL,
SelectionDAG &DAG) const {
assert(Chain.getOpcode() == ISD::EntryToken && "Unexpected Chain value");
SDValue Glue = Chain.getValue(1);

MachineFunction &MF = DAG.getMachineFunction();
SMEAttrs SMEFnAttrs = MF.getInfo<AArch64FunctionInfo>()->getSMEFnAttrs();
auto &FuncInfo = *MF.getInfo<AArch64FunctionInfo>();
auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
const AArch64RegisterInfo &TRI = *Subtarget.getRegisterInfo();

SMEAttrs SMEFnAttrs = FuncInfo.getSMEFnAttrs();

// The following conditions are true on entry to an exception handler:
// - PSTATE.SM is 0.
Expand All @@ -8115,14 +8178,43 @@ SDValue AArch64TargetLowering::lowerEHPadEntry(SDValue Chain, SDLoc const &DL,
// These mode changes are usually optimized away in catch blocks as they
// occur before the __cxa_begin_catch (which is a non-streaming function),
// but are necessary in some cases (such as for cleanups).
//
// Additionally, if the function has ZA or ZT0 state, we must restore it.

// [COND_]SMSTART SM
if (SMEFnAttrs.hasStreamingInterfaceOrBody())
return changeStreamingMode(DAG, DL, /*Enable=*/true, Chain,
/*Glue*/ Glue, AArch64SME::Always);
Chain = changeStreamingMode(DAG, DL, /*Enable=*/true, Chain,
/*Glue*/ Glue, AArch64SME::Always);
else if (SMEFnAttrs.hasStreamingCompatibleInterface())
Chain = changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, Glue,
AArch64SME::IfCallerIsStreaming);

if (getTM().useNewSMEABILowering())
return Chain;

if (SMEFnAttrs.hasStreamingCompatibleInterface())
return changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, Glue,
AArch64SME::IfCallerIsStreaming);
if (SMEFnAttrs.hasAgnosticZAInterface()) {
// Restore full ZA
Chain = emitSMEStateSaveRestore(*this, DAG, &FuncInfo, DL, Chain,
/*IsSave=*/false);
} else if (SMEFnAttrs.hasZAState() || SMEFnAttrs.hasZT0State()) {
// SMSTART ZA
Chain = DAG.getNode(
AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
DAG.getTargetConstant(int32_t(AArch64SVCR::SVCRZA), DL, MVT::i32));

// Restore ZT0
if (SMEFnAttrs.hasZT0State()) {
SDValue ZT0FrameIndex =
getZT0FrameIndex(MF.getFrameInfo(), FuncInfo, DAG);
Chain =
DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
{Chain, DAG.getConstant(0, DL, MVT::i32), ZT0FrameIndex});
}

// Restore ZA
if (SMEFnAttrs.hasZAState())
Chain = emitRestoreZALazySave(Chain, DL, *this, TRI, FuncInfo, DAG);
}

return Chain;
}
Expand Down Expand Up @@ -9240,30 +9332,6 @@ SDValue AArch64TargetLowering::changeStreamingMode(
return GetCheckVL(SMChange.getValue(0), SMChange.getValue(1));
}

// Emit a call to __arm_sme_save or __arm_sme_restore.
static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
SelectionDAG &DAG,
AArch64FunctionInfo *Info, SDLoc DL,
SDValue Chain, bool IsSave) {
MachineFunction &MF = DAG.getMachineFunction();
AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
FuncInfo->setSMESaveBufferUsed();
TargetLowering::ArgListTy Args;
Args.emplace_back(
DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64),
PointerType::getUnqual(*DAG.getContext()));

RTLIB::Libcall LC =
IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE;
SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC),
TLI.getPointerTy(DAG.getDataLayout()));
auto *RetTy = Type::getVoidTy(*DAG.getContext());
TargetLowering::CallLoweringInfo CLI(DAG);
CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
return TLI.LowerCallTo(CLI).second;
}

static AArch64SME::ToggleCondition
getSMToggleCondition(const SMECallAttrs &CallAttrs) {
if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
Expand Down Expand Up @@ -10023,33 +10091,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
{Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});

if (RequiresLazySave) {
// Conditionally restore the lazy save using a pseudo node.
RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE;
TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
SDValue RegMask = DAG.getRegisterMask(
TRI->getCallPreservedMask(MF, getLibcallCallingConv(LC)));
SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
getLibcallName(LC), getPointerTy(DAG.getDataLayout()));
SDValue TPIDR2_EL0 = DAG.getNode(
ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
// Copy the address of the TPIDR2 block into X0 before 'calling' the
// RESTORE_ZA pseudo.
SDValue Glue;
SDValue TPIDR2Block = DAG.getFrameIndex(
TPIDR2.FrameIndex,
DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
Result =
DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
{Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
RestoreRoutine, RegMask, Result.getValue(1)});
// Finally reset the TPIDR2_EL0 register to 0.
Result = DAG.getNode(
ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
DAG.getConstant(0, DL, MVT::i64));
TPIDR2.Uses++;
Result = emitRestoreZALazySave(Result, DL, *this, *TRI, *FuncInfo, DAG);
} else if (RequiresSaveAllZA) {
Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Result,
/*IsSave=*/false);
Expand Down
Loading