102 changes: 94 additions & 8 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2940,6 +2940,63 @@ AArch64TargetLowering::EmitDynamicProbedAlloc(MachineInstr &MI,
return NextInst->getParent();
}

MachineBasicBlock *
AArch64TargetLowering::EmitCheckMatchingVL(MachineInstr &MI,
MachineBasicBlock *MBB) const {
MachineFunction *MF = MBB->getParent();
MachineRegisterInfo &MRI = MF->getRegInfo();

const TargetRegisterClass *RC_GPR = &AArch64::GPR64RegClass;
const TargetRegisterClass *RC_GPRsp = &AArch64::GPR64spRegClass;
Contributor: Super-nit: Rename to something explicit like MBBInsertPoint

Collaborator: Super-nit: Also please move this closer to its use rather than defining it at the top of the function.


Register RegVL_GPR = MRI.createVirtualRegister(RC_GPR);
Register RegVL_GPRsp = MRI.createVirtualRegister(RC_GPRsp); // for ADDSVL src
Register RegSVL_GPR = MRI.createVirtualRegister(RC_GPR);
Register RegSVL_GPRsp = MRI.createVirtualRegister(RC_GPRsp); // for ADDSVL dst

const TargetInstrInfo *TII = Subtarget->getInstrInfo();
DebugLoc DL = MI.getDebugLoc();

// RDVL requires GPR64, ADDSVL requires GPR64sp.
// We need to insert COPY instructions; these will later be removed by the
// RegisterCoalescer.
Collaborator: nit: this comment (and some similar comments below) is redundant because it's clear from the statement below that this is what it's doing. It might be more helpful to explain (once) why the COPYs are needed, e.g. something along the lines of "because ADDSVL requires GPR64sp and RDVL requires GPR64, we need to insert some COPYs that will be removed by the RegisterCoalescer".
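One way the suggested consolidation could read, keeping the emitted sequence from the patch unchanged (sketch only, not part of the merged change):

// Because ADDSVL requires GPR64sp and RDVL requires GPR64, insert COPYs
// between the two register classes; the RegisterCoalescer removes them later.
BuildMI(*MBB, MI, DL, TII->get(AArch64::RDVLI_XI), RegVL_GPR).addImm(1);
BuildMI(*MBB, MI, DL, TII->get(TargetOpcode::COPY), RegVL_GPRsp)
    .addReg(RegVL_GPR);
BuildMI(*MBB, MI, DL, TII->get(AArch64::ADDSVL_XXI), RegSVL_GPRsp)
    .addReg(RegVL_GPRsp)
    .addImm(-1);
BuildMI(*MBB, MI, DL, TII->get(TargetOpcode::COPY), RegSVL_GPR)
    .addReg(RegSVL_GPRsp);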

BuildMI(*MBB, MI, DL, TII->get(AArch64::RDVLI_XI), RegVL_GPR).addImm(1);
BuildMI(*MBB, MI, DL, TII->get(TargetOpcode::COPY), RegVL_GPRsp)
.addReg(RegVL_GPR);

BuildMI(*MBB, MI, DL, TII->get(AArch64::ADDSVL_XXI), RegSVL_GPRsp)
.addReg(RegVL_GPRsp)
.addImm(-1);
BuildMI(*MBB, MI, DL, TII->get(TargetOpcode::COPY), RegSVL_GPR)
.addReg(RegSVL_GPRsp);

const BasicBlock *LLVM_BB = MBB->getBasicBlock();
MachineFunction::iterator It = ++MBB->getIterator();
MachineBasicBlock *TrapBB = MF->CreateMachineBasicBlock(LLVM_BB);
MachineBasicBlock *PassBB = MF->CreateMachineBasicBlock(LLVM_BB);
MF->insert(It, TrapBB);
MF->insert(It, PassBB);

// Continue if vector lengths match
BuildMI(*MBB, MI, DL, TII->get(AArch64::CBZX))
.addReg(RegSVL_GPR)
.addMBB(PassBB);

// Transfer rest of current BB to PassBB
PassBB->splice(PassBB->begin(), MBB,
std::next(MachineBasicBlock::iterator(MI)), MBB->end());
PassBB->transferSuccessorsAndUpdatePHIs(MBB);

// Trap if vector lengths mismatch
BuildMI(TrapBB, DL, TII->get(AArch64::BRK)).addImm(1);

MBB->addSuccessor(TrapBB);
MBB->addSuccessor(PassBB);

MI.eraseFromParent();
return PassBB;
}

MachineBasicBlock *
AArch64TargetLowering::EmitTileLoad(unsigned Opc, unsigned BaseReg,
MachineInstr &MI,
@@ -3343,6 +3400,9 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
case AArch64::PROBED_STACKALLOC_DYN:
return EmitDynamicProbedAlloc(MI, BB);

case AArch64::CHECK_MATCHING_VL_PSEUDO:
return EmitCheckMatchingVL(MI, BB);

case AArch64::LD1_MXIPXX_H_PSEUDO_B:
return EmitTileLoad(AArch64::LD1_MXIPXX_H_B, AArch64::ZAB0, MI, BB);
case AArch64::LD1_MXIPXX_H_PSEUDO_H:
@@ -9113,14 +9173,29 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
}
}

SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
bool Enable, SDValue Chain,
SDValue InGlue,
unsigned Condition) const {
SDValue AArch64TargetLowering::changeStreamingMode(
SelectionDAG &DAG, SDLoc DL, bool Enable, SDValue Chain, SDValue InGlue,
unsigned Condition, bool InsertVectorLengthCheck) const {
MachineFunction &MF = DAG.getMachineFunction();
AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
FuncInfo->setHasStreamingModeChanges(true);

auto GetCheckVL = [&](SDValue Chain, SDValue InGlue = SDValue()) -> SDValue {
SmallVector<SDValue, 2> Ops = {Chain};
if (InGlue)
Ops.push_back(InGlue);
return DAG.getNode(AArch64ISD::CHECK_MATCHING_VL, DL,
DAG.getVTList(MVT::Other, MVT::Glue), Ops);
};

if (InsertVectorLengthCheck && Enable) {
// Non-streaming -> Streaming
// Insert vector length check before smstart
SDValue CheckVL = GetCheckVL(Chain, InGlue);
Chain = CheckVL.getValue(0);
InGlue = CheckVL.getValue(1);
}

const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
SDValue RegMask = DAG.getRegisterMask(TRI->getSMStartStopCallPreservedMask());
SDValue MSROp =
@@ -9147,7 +9222,16 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
if (InGlue)
Ops.push_back(InGlue);

return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
SDValue SMChange =
DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);

if (!InsertVectorLengthCheck || Enable)
return SMChange;

// Streaming -> Non-streaming
// Insert vector length check after smstop since we cannot read VL
// in streaming mode
return GetCheckVL(SMChange.getValue(0), SMChange.getValue(1));
}

// Emit a call to __arm_sme_save or __arm_sme_restore.
@@ -9730,9 +9814,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,

SDValue InGlue;
if (RequiresSMChange) {
Chain =
changeStreamingMode(DAG, DL, CallAttrs.callee().hasStreamingInterface(),
Chain, InGlue, getSMToggleCondition(CallAttrs));
bool InsertVectorLengthCheck =
(CallConv == CallingConv::AArch64_SVE_VectorCall);
Chain = changeStreamingMode(
DAG, DL, CallAttrs.callee().hasStreamingInterface(), Chain, InGlue,
getSMToggleCondition(CallAttrs), InsertVectorLengthCheck);
InGlue = Chain.getValue(1);
}

7 changes: 5 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -168,6 +168,9 @@ class AArch64TargetLowering : public TargetLowering {
MachineBasicBlock *EmitDynamicProbedAlloc(MachineInstr &MI,
MachineBasicBlock *MBB) const;

MachineBasicBlock *EmitCheckMatchingVL(MachineInstr &MI,
MachineBasicBlock *MBB) const;

MachineBasicBlock *EmitTileLoad(unsigned Opc, unsigned BaseReg,
MachineInstr &MI,
MachineBasicBlock *BB) const;
@@ -532,8 +535,8 @@ class AArch64TargetLowering : public TargetLowering {
/// node. \p Condition should be one of the enum values from
/// AArch64SME::ToggleCondition.
SDValue changeStreamingMode(SelectionDAG &DAG, SDLoc DL, bool Enable,
SDValue Chain, SDValue InGlue,
unsigned Condition) const;
SDValue Chain, SDValue InGlue, unsigned Condition,
bool InsertVectorLengthCheck = false) const;

bool isVScaleKnownToBeAPowerOfTwo() const override { return true; }

11 changes: 11 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -48,6 +48,17 @@ let usesCustomInserter = 1 in {
}
def : Pat<(i64 (AArch64EntryPStateSM)), (EntryPStateSM)>;

// Pseudo-instruction that compares the current SVE vector length (VL) with the
// streaming vector length (SVL). If the two lengths do not match, the check
// lowers to a `brk`, causing a trap.
let hasSideEffects = 1, isCodeGenOnly = 1, usesCustomInserter = 1 in
def CHECK_MATCHING_VL_PSEUDO : Pseudo<(outs), (ins), []>, Sched<[]>;

def AArch64_check_matching_vl
: SDNode<"AArch64ISD::CHECK_MATCHING_VL", SDTypeProfile<0, 0,[]>,
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue]>;
def : Pat<(AArch64_check_matching_vl), (CHECK_MATCHING_VL_PSEUDO)>;

//===----------------------------------------------------------------------===//
// Old SME ABI lowering ISD nodes/pseudos (deprecated)
//===----------------------------------------------------------------------===//
24 changes: 18 additions & 6 deletions llvm/test/CodeGen/AArch64/sme-callee-save-restore-pairs.ll
@@ -47,12 +47,18 @@ define void @fbyte(<vscale x 16 x i8> %v) #0{
; NOPAIR-NEXT: // %bb.1:
; NOPAIR-NEXT: smstop sm
; NOPAIR-NEXT: .LBB0_2:
; NOPAIR-NEXT: rdvl x8, #1
; NOPAIR-NEXT: addsvl x8, x8, #-1
; NOPAIR-NEXT: cbz x8, .LBB0_4
; NOPAIR-NEXT: // %bb.3:
; NOPAIR-NEXT: brk #0x1
; NOPAIR-NEXT: .LBB0_4:
; NOPAIR-NEXT: ldr z0, [sp] // 16-byte Folded Reload
; NOPAIR-NEXT: bl my_func2
; NOPAIR-NEXT: tbz w19, #0, .LBB0_4
; NOPAIR-NEXT: // %bb.3:
; NOPAIR-NEXT: tbz w19, #0, .LBB0_6
; NOPAIR-NEXT: // %bb.5:
; NOPAIR-NEXT: smstart sm
; NOPAIR-NEXT: .LBB0_4:
; NOPAIR-NEXT: .LBB0_6:
; NOPAIR-NEXT: addvl sp, sp, #1
; NOPAIR-NEXT: ldr z23, [sp, #2, mul vl] // 16-byte Folded Reload
; NOPAIR-NEXT: ldr z22, [sp, #3, mul vl] // 16-byte Folded Reload
@@ -127,12 +133,18 @@ define void @fbyte(<vscale x 16 x i8> %v) #0{
; PAIR-NEXT: // %bb.1:
; PAIR-NEXT: smstop sm
; PAIR-NEXT: .LBB0_2:
; PAIR-NEXT: rdvl x8, #1
; PAIR-NEXT: addsvl x8, x8, #-1
; PAIR-NEXT: cbz x8, .LBB0_4
; PAIR-NEXT: // %bb.3:
; PAIR-NEXT: brk #0x1
; PAIR-NEXT: .LBB0_4:
; PAIR-NEXT: ldr z0, [sp] // 16-byte Folded Reload
; PAIR-NEXT: bl my_func2
; PAIR-NEXT: tbz w19, #0, .LBB0_4
; PAIR-NEXT: // %bb.3:
; PAIR-NEXT: tbz w19, #0, .LBB0_6
; PAIR-NEXT: // %bb.5:
; PAIR-NEXT: smstart sm
; PAIR-NEXT: .LBB0_4:
; PAIR-NEXT: .LBB0_6:
; PAIR-NEXT: addvl sp, sp, #1
; PAIR-NEXT: ldr z23, [sp, #2, mul vl] // 16-byte Folded Reload
; PAIR-NEXT: ldr z22, [sp, #3, mul vl] // 16-byte Folded Reload
12 changes: 11 additions & 1 deletion llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
@@ -527,14 +527,24 @@ define void @test13(ptr %ptr) nounwind "aarch64_pstate_sm_enabled" {
; CHECK-NEXT: stp x30, x19, [sp, #80] // 16-byte Folded Spill
; CHECK-NEXT: addvl sp, sp, #-1
; CHECK-NEXT: mov z0.s, #0 // =0x0
; CHECK-NEXT: mov x19, x0
; CHECK-NEXT: str z0, [sp] // 16-byte Folded Spill
; CHECK-NEXT: smstop sm
; CHECK-NEXT: rdvl x8, #1
; CHECK-NEXT: addsvl x8, x8, #-1
; CHECK-NEXT: cbnz x8, .LBB14_2
; CHECK-NEXT: // %bb.1:
; CHECK-NEXT: ldr z0, [sp] // 16-byte Folded Reload
; CHECK-NEXT: mov x19, x0
; CHECK-NEXT: bl callee_farg_fret
; CHECK-NEXT: str z0, [sp] // 16-byte Folded Spill
; CHECK-NEXT: smstart sm
; CHECK-NEXT: smstop sm
; CHECK-NEXT: rdvl x8, #1
; CHECK-NEXT: addsvl x8, x8, #-1
; CHECK-NEXT: cbz x8, .LBB14_3
; CHECK-NEXT: .LBB14_2:
; CHECK-NEXT: brk #0x1
; CHECK-NEXT: .LBB14_3:
; CHECK-NEXT: ldr z0, [sp] // 16-byte Folded Reload
; CHECK-NEXT: bl callee_farg_fret
; CHECK-NEXT: str z0, [sp] // 16-byte Folded Spill