-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[AArch64][SME] Introduce CHECK_MATCHING_VL pseudo for streaming transitions #157510
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
aae8b5b
ac52098
f3f31c9
3a2c02a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2940,6 +2940,63 @@ AArch64TargetLowering::EmitDynamicProbedAlloc(MachineInstr &MI, | |
return NextInst->getParent(); | ||
} | ||
|
||
MachineBasicBlock * | ||
AArch64TargetLowering::EmitCheckMatchingVL(MachineInstr &MI, | ||
MachineBasicBlock *MBB) const { | ||
MachineFunction *MF = MBB->getParent(); | ||
MachineRegisterInfo &MRI = MF->getRegInfo(); | ||
|
||
const TargetRegisterClass *RC_GPR = &AArch64::GPR64RegClass; | ||
const TargetRegisterClass *RC_GPRsp = &AArch64::GPR64spRegClass; | ||
|
||
|
||
Register RegVL_GPR = MRI.createVirtualRegister(RC_GPR); | ||
Register RegVL_GPRsp = MRI.createVirtualRegister(RC_GPRsp); // for ADDSVL src | ||
Register RegSVL_GPR = MRI.createVirtualRegister(RC_GPR); | ||
Register RegSVL_GPRsp = MRI.createVirtualRegister(RC_GPRsp); // for ADDSVL dst | ||
|
||
const TargetInstrInfo *TII = Subtarget->getInstrInfo(); | ||
DebugLoc DL = MI.getDebugLoc(); | ||
|
||
marykass-arm marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
// RDVL requires GPR64, ADDSVL requires GPR64sp | ||
// We need to insert COPY instructions, these will later be removed by the | ||
marykass-arm marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
// RegisterCoalescer | ||
|
||
BuildMI(*MBB, MI, DL, TII->get(AArch64::RDVLI_XI), RegVL_GPR).addImm(1); | ||
BuildMI(*MBB, MI, DL, TII->get(TargetOpcode::COPY), RegVL_GPRsp) | ||
.addReg(RegVL_GPR); | ||
|
||
BuildMI(*MBB, MI, DL, TII->get(AArch64::ADDSVL_XXI), RegSVL_GPRsp) | ||
.addReg(RegVL_GPRsp) | ||
.addImm(-1); | ||
BuildMI(*MBB, MI, DL, TII->get(TargetOpcode::COPY), RegSVL_GPR) | ||
.addReg(RegSVL_GPRsp); | ||
|
||
const BasicBlock *LLVM_BB = MBB->getBasicBlock(); | ||
MachineFunction::iterator It = ++MBB->getIterator(); | ||
MachineBasicBlock *TrapBB = MF->CreateMachineBasicBlock(LLVM_BB); | ||
MachineBasicBlock *PassBB = MF->CreateMachineBasicBlock(LLVM_BB); | ||
MF->insert(It, TrapBB); | ||
MF->insert(It, PassBB); | ||
|
||
// Continue if vector lengths match | ||
BuildMI(*MBB, MI, DL, TII->get(AArch64::CBZX)) | ||
.addReg(RegSVL_GPR) | ||
.addMBB(PassBB); | ||
|
||
// Transfer rest of current BB to PassBB | ||
PassBB->splice(PassBB->begin(), MBB, | ||
std::next(MachineBasicBlock::iterator(MI)), MBB->end()); | ||
PassBB->transferSuccessorsAndUpdatePHIs(MBB); | ||
|
||
// Trap if vector lengths mismatch | ||
BuildMI(TrapBB, DL, TII->get(AArch64::BRK)).addImm(1); | ||
|
||
MBB->addSuccessor(TrapBB); | ||
MBB->addSuccessor(PassBB); | ||
|
||
MI.eraseFromParent(); | ||
return PassBB; | ||
} | ||
|
||
MachineBasicBlock * | ||
AArch64TargetLowering::EmitTileLoad(unsigned Opc, unsigned BaseReg, | ||
MachineInstr &MI, | ||
|
@@ -3343,6 +3400,9 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter( | |
case AArch64::PROBED_STACKALLOC_DYN: | ||
return EmitDynamicProbedAlloc(MI, BB); | ||
|
||
case AArch64::CHECK_MATCHING_VL_PSEUDO: | ||
return EmitCheckMatchingVL(MI, BB); | ||
|
||
case AArch64::LD1_MXIPXX_H_PSEUDO_B: | ||
return EmitTileLoad(AArch64::LD1_MXIPXX_H_B, AArch64::ZAB0, MI, BB); | ||
case AArch64::LD1_MXIPXX_H_PSEUDO_H: | ||
|
@@ -9113,14 +9173,29 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI, | |
} | ||
} | ||
|
||
SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL, | ||
bool Enable, SDValue Chain, | ||
SDValue InGlue, | ||
unsigned Condition) const { | ||
SDValue AArch64TargetLowering::changeStreamingMode( | ||
SelectionDAG &DAG, SDLoc DL, bool Enable, SDValue Chain, SDValue InGlue, | ||
unsigned Condition, bool InsertVectorLengthCheck) const { | ||
MachineFunction &MF = DAG.getMachineFunction(); | ||
AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>(); | ||
FuncInfo->setHasStreamingModeChanges(true); | ||
|
||
auto GetCheckVL = [&](SDValue Chain, SDValue InGlue = SDValue()) -> SDValue { | ||
SmallVector<SDValue, 2> Ops = {Chain}; | ||
if (InGlue) | ||
Ops.push_back(InGlue); | ||
return DAG.getNode(AArch64ISD::CHECK_MATCHING_VL, DL, | ||
DAG.getVTList(MVT::Other, MVT::Glue), Ops); | ||
}; | ||
|
||
if (InsertVectorLengthCheck && Enable) { | ||
// Non-streaming -> Streaming | ||
// Insert vector length check before smstart | ||
SDValue CheckVL = GetCheckVL(Chain, InGlue); | ||
Chain = CheckVL.getValue(0); | ||
InGlue = CheckVL.getValue(1); | ||
} | ||
|
||
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo(); | ||
SDValue RegMask = DAG.getRegisterMask(TRI->getSMStartStopCallPreservedMask()); | ||
marykass-arm marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
SDValue MSROp = | ||
|
@@ -9147,7 +9222,16 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL, | |
if (InGlue) | ||
Ops.push_back(InGlue); | ||
|
||
return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops); | ||
SDValue SMChange = | ||
DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops); | ||
|
||
if (!InsertVectorLengthCheck || Enable) | ||
return SMChange; | ||
marykass-arm marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
// Streaming -> Non-streaming | ||
// Insert vector length check after smstop since we cannot read VL | ||
// in streaming mode | ||
return GetCheckVL(SMChange.getValue(0), SMChange.getValue(1)); | ||
} | ||
|
||
// Emit a call to __arm_sme_save or __arm_sme_restore. | ||
|
@@ -9730,9 +9814,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, | |
|
||
SDValue InGlue; | ||
if (RequiresSMChange) { | ||
Chain = | ||
changeStreamingMode(DAG, DL, CallAttrs.callee().hasStreamingInterface(), | ||
Chain, InGlue, getSMToggleCondition(CallAttrs)); | ||
bool InsertVectorLengthCheck = | ||
(CallConv == CallingConv::AArch64_SVE_VectorCall); | ||
Chain = changeStreamingMode( | ||
marykass-arm marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
DAG, DL, CallAttrs.callee().hasStreamingInterface(), Chain, InGlue, | ||
getSMToggleCondition(CallAttrs), InsertVectorLengthCheck); | ||
InGlue = Chain.getValue(1); | ||
} | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.