Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SME] Stop RA from coalescing COPY instructions that transcend beyond smstart/smstop. #78294

Merged
merged 4 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,12 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
NextMBBI = MBB.end(); // The NextMBBI iterator is invalidated.
return true;
}
case AArch64::COALESCER_BARRIER_FPR16:
case AArch64::COALESCER_BARRIER_FPR32:
case AArch64::COALESCER_BARRIER_FPR64:
case AArch64::COALESCER_BARRIER_FPR128:
MI.eraseFromParent();
return true;
case AArch64::LD1B_2Z_IMM_PSEUDO:
return expandMultiVecPseudo(
MBB, MBBI, AArch64::ZPR2RegClass, AArch64::ZPR2StridedRegClass,
Expand Down
24 changes: 20 additions & 4 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2375,6 +2375,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
switch ((AArch64ISD::NodeType)Opcode) {
case AArch64ISD::FIRST_NUMBER:
break;
MAKE_CASE(AArch64ISD::COALESCER_BARRIER)
MAKE_CASE(AArch64ISD::SMSTART)
MAKE_CASE(AArch64ISD::SMSTOP)
MAKE_CASE(AArch64ISD::RESTORE_ZA)
Expand Down Expand Up @@ -7154,13 +7155,18 @@ void AArch64TargetLowering::saveVarArgRegisters(CCState &CCInfo,
}
}

static bool isPassedInFPR(EVT VT) {
return VT.isFixedLengthVector() ||
(VT.isFloatingPoint() && !VT.isScalableVector());
}

/// LowerCallResult - Lower the result values of a call into the
/// appropriate copies out of appropriate physical registers.
SDValue AArch64TargetLowering::LowerCallResult(
SDValue Chain, SDValue InGlue, CallingConv::ID CallConv, bool isVarArg,
const SmallVectorImpl<CCValAssign> &RVLocs, const SDLoc &DL,
SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals, bool isThisReturn,
SDValue ThisVal) const {
SDValue ThisVal, bool RequiresSMChange) const {
DenseMap<unsigned, SDValue> CopiedRegs;
// Copy all of the result registers out of their specified physreg.
for (unsigned i = 0; i != RVLocs.size(); ++i) {
Expand Down Expand Up @@ -7205,6 +7211,10 @@ SDValue AArch64TargetLowering::LowerCallResult(
break;
}

if (RequiresSMChange && isPassedInFPR(VA.getValVT()))
Val = DAG.getNode(AArch64ISD::COALESCER_BARRIER, DL, Val.getValueType(),
Val);

InVals.push_back(Val);
}

Expand Down Expand Up @@ -7915,6 +7925,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
return ArgReg.Reg == VA.getLocReg();
});
} else {
// Add an extra level of indirection for streaming mode changes by
// using a pseudo copy node that cannot be rematerialised between a
// smstart/smstop and the call by the simple register coalescer.
if (RequiresSMChange && isPassedInFPR(Arg.getValueType()))
Arg = DAG.getNode(AArch64ISD::COALESCER_BARRIER, DL,
Arg.getValueType(), Arg);
RegsToPass.emplace_back(VA.getLocReg(), Arg);
RegsUsed.insert(VA.getLocReg());
const TargetOptions &Options = DAG.getTarget().Options;
Expand Down Expand Up @@ -8151,9 +8167,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,

// Handle result values, copying them out of physregs into vregs that we
// return.
SDValue Result = LowerCallResult(Chain, InGlue, CallConv, IsVarArg, RVLocs,
DL, DAG, InVals, IsThisReturn,
IsThisReturn ? OutVals[0] : SDValue());
SDValue Result = LowerCallResult(
Chain, InGlue, CallConv, IsVarArg, RVLocs, DL, DAG, InVals, IsThisReturn,
IsThisReturn ? OutVals[0] : SDValue(), RequiresSMChange);

if (!Ins.empty())
InGlue = Result.getValue(Result->getNumValues() - 1);
Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ enum NodeType : unsigned {

CALL_BTI, // Function call followed by a BTI instruction.

COALESCER_BARRIER,

SMSTART,
SMSTOP,
RESTORE_ZA,
Expand Down Expand Up @@ -1026,7 +1028,7 @@ class AArch64TargetLowering : public TargetLowering {
const SmallVectorImpl<CCValAssign> &RVLocs,
const SDLoc &DL, SelectionDAG &DAG,
SmallVectorImpl<SDValue> &InVals, bool isThisReturn,
SDValue ThisVal) const;
SDValue ThisVal, bool RequiresSMChange) const;

SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
Expand Down
35 changes: 35 additions & 0 deletions llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,8 @@ bool AArch64RegisterInfo::shouldCoalesce(
MachineInstr *MI, const TargetRegisterClass *SrcRC, unsigned SubReg,
const TargetRegisterClass *DstRC, unsigned DstSubReg,
const TargetRegisterClass *NewRC, LiveIntervals &LIS) const {
MachineRegisterInfo &MRI = MI->getMF()->getRegInfo();

if (MI->isCopy() &&
((DstRC->getID() == AArch64::GPR64RegClassID) ||
(DstRC->getID() == AArch64::GPR64commonRegClassID)) &&
Expand All @@ -1023,5 +1025,38 @@ bool AArch64RegisterInfo::shouldCoalesce(
// which implements a 32 to 64 bit zero extension
// which relies on the upper 32 bits being zeroed.
return false;

auto IsCoalescerBarrier = [](const MachineInstr &MI) {
switch (MI.getOpcode()) {
case AArch64::COALESCER_BARRIER_FPR16:
case AArch64::COALESCER_BARRIER_FPR32:
case AArch64::COALESCER_BARRIER_FPR64:
case AArch64::COALESCER_BARRIER_FPR128:
return true;
default:
return false;
}
};

// For calls that temporarily have to toggle streaming mode as part of the
// call-sequence, we need to be more careful when coalescing copy instructions
// so that we don't end up coalescing the NEON/FP result or argument register
// with a whole Z-register, such that after coalescing the register allocator
// will try to spill/reload the entire Z register.
//
// We do this by checking if the node has any defs/uses that are
// COALESCER_BARRIER pseudos. These are 'nops' in practice, but they exist to
// instruct the coalescer to avoid coalescing the copy.
if (MI->isCopy() && SubReg != DstSubReg &&
(AArch64::ZPRRegClass.hasSubClassEq(DstRC) ||
AArch64::ZPRRegClass.hasSubClassEq(SrcRC))) {
unsigned SrcReg = MI->getOperand(1).getReg();
if (any_of(MRI.def_instructions(SrcReg), IsCoalescerBarrier))
return false;
unsigned DstReg = MI->getOperand(0).getReg();
if (any_of(MRI.use_nodbg_instructions(DstReg), IsCoalescerBarrier))
return false;
}

return true;
}
22 changes: 22 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def AArch64_restore_zt : SDNode<"AArch64ISD::RESTORE_ZT", SDTypeProfile<0, 2,
def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
[SDTCisInt<0>, SDTCisPtrTy<1>]>,
[SDNPHasChain, SDNPSideEffect, SDNPMayStore]>;
def AArch64CoalescerBarrier
: SDNode<"AArch64ISD::COALESCER_BARRIER", SDTypeProfile<1, 1, []>, []>;

//===----------------------------------------------------------------------===//
// Instruction naming conventions.
Expand Down Expand Up @@ -189,6 +191,26 @@ def : Pat<(int_aarch64_sme_set_tpidr2 i64:$val),
(MSR 0xde85, GPR64:$val)>;
def : Pat<(i64 (int_aarch64_sme_get_tpidr2)),
(MRS 0xde85)>;

multiclass CoalescerBarrierPseudo<RegisterClass rc, list<ValueType> vts> {
def NAME : Pseudo<(outs rc:$dst), (ins rc:$src), []>, Sched<[]> {
let Constraints = "$dst = $src";
}
foreach vt = vts in {
def : Pat<(vt (AArch64CoalescerBarrier (vt rc:$src))),
(!cast<Instruction>(NAME) rc:$src)>;
}
}

multiclass CoalescerBarriers {
defm _FPR16 : CoalescerBarrierPseudo<FPR16, [bf16, f16]>;
defm _FPR32 : CoalescerBarrierPseudo<FPR32, [f32]>;
defm _FPR64 : CoalescerBarrierPseudo<FPR64, [f64, v8i8, v4i16, v2i32, v1i64, v4f16, v2f32, v1f64, v4bf16]>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to worry about unusual types such as v8i7? I assume these get promoted to v8i8 anyway, but I wasn't sure if we should perhaps have at least one test for them. Similarly, I wonder what happens with v8i1?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As you say, we should only have legal types at this point. I've added a test for v8i1, but I'm not sure how to add a test for v8i7, as there will be no legalisation for <vscale x 8 x i7>.

defm _FPR128 : CoalescerBarrierPseudo<FPR128, [f128, v16i8, v8i16, v4i32, v2i64, v8f16, v4f32, v2f64, v8bf16]>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to worry about v1i128 - or is that not even legal?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding v1i128 leads to TableGen issues that say Type set is empty for each HW mode, so i don't think we need to worry about it.

}

defm COALESCER_BARRIER : CoalescerBarriers;

} // End let Predicates = [HasSME]

// Pseudo to match to smstart/smstop. This expands:
Expand Down
20 changes: 12 additions & 8 deletions llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ define double @nonstreaming_caller_streaming_callee(double %x) nounwind noinline
; CHECK-FISEL-NEXT: bl streaming_callee
; CHECK-FISEL-NEXT: str d0, [sp, #8] // 8-byte Folded Spill
; CHECK-FISEL-NEXT: smstop sm
; CHECK-FISEL-NEXT: ldr d1, [sp, #8] // 8-byte Folded Reload
; CHECK-FISEL-NEXT: adrp x8, .LCPI0_0
; CHECK-FISEL-NEXT: ldr d0, [x8, :lo12:.LCPI0_0]
; CHECK-FISEL-NEXT: ldr d1, [sp, #8] // 8-byte Folded Reload
; CHECK-FISEL-NEXT: fadd d0, d1, d0
; CHECK-FISEL-NEXT: ldr x30, [sp, #80] // 8-byte Folded Reload
; CHECK-FISEL-NEXT: ldp d9, d8, [sp, #64] // 16-byte Folded Reload
Expand All @@ -49,9 +49,9 @@ define double @nonstreaming_caller_streaming_callee(double %x) nounwind noinline
; CHECK-GISEL-NEXT: bl streaming_callee
; CHECK-GISEL-NEXT: str d0, [sp, #8] // 8-byte Folded Spill
; CHECK-GISEL-NEXT: smstop sm
; CHECK-GISEL-NEXT: ldr d1, [sp, #8] // 8-byte Folded Reload
; CHECK-GISEL-NEXT: mov x8, #4631107791820423168 // =0x4045000000000000
; CHECK-GISEL-NEXT: fmov d0, x8
; CHECK-GISEL-NEXT: ldr d1, [sp, #8] // 8-byte Folded Reload
; CHECK-GISEL-NEXT: fadd d0, d1, d0
; CHECK-GISEL-NEXT: ldr x30, [sp, #80] // 8-byte Folded Reload
; CHECK-GISEL-NEXT: ldp d9, d8, [sp, #64] // 16-byte Folded Reload
Expand Down Expand Up @@ -82,9 +82,9 @@ define double @streaming_caller_nonstreaming_callee(double %x) nounwind noinline
; CHECK-COMMON-NEXT: bl normal_callee
; CHECK-COMMON-NEXT: str d0, [sp, #8] // 8-byte Folded Spill
; CHECK-COMMON-NEXT: smstart sm
; CHECK-COMMON-NEXT: ldr d1, [sp, #8] // 8-byte Folded Reload
; CHECK-COMMON-NEXT: mov x8, #4631107791820423168 // =0x4045000000000000
; CHECK-COMMON-NEXT: fmov d0, x8
; CHECK-COMMON-NEXT: ldr d1, [sp, #8] // 8-byte Folded Reload
; CHECK-COMMON-NEXT: fadd d0, d1, d0
; CHECK-COMMON-NEXT: ldr x30, [sp, #80] // 8-byte Folded Reload
; CHECK-COMMON-NEXT: ldp d9, d8, [sp, #64] // 16-byte Folded Reload
Expand All @@ -110,14 +110,16 @@ define double @locally_streaming_caller_normal_callee(double %x) nounwind noinli
; CHECK-COMMON-NEXT: str x30, [sp, #96] // 8-byte Folded Spill
; CHECK-COMMON-NEXT: str d0, [sp, #24] // 8-byte Folded Spill
; CHECK-COMMON-NEXT: smstart sm
; CHECK-COMMON-NEXT: ldr d0, [sp, #24] // 8-byte Folded Reload
; CHECK-COMMON-NEXT: str d0, [sp, #24] // 8-byte Folded Spill
; CHECK-COMMON-NEXT: smstop sm
; CHECK-COMMON-NEXT: ldr d0, [sp, #24] // 8-byte Folded Reload
; CHECK-COMMON-NEXT: bl normal_callee
; CHECK-COMMON-NEXT: str d0, [sp, #16] // 8-byte Folded Spill
; CHECK-COMMON-NEXT: smstart sm
; CHECK-COMMON-NEXT: ldr d1, [sp, #16] // 8-byte Folded Reload
; CHECK-COMMON-NEXT: mov x8, #4631107791820423168 // =0x4045000000000000
; CHECK-COMMON-NEXT: fmov d0, x8
; CHECK-COMMON-NEXT: ldr d1, [sp, #16] // 8-byte Folded Reload
; CHECK-COMMON-NEXT: fadd d0, d1, d0
; CHECK-COMMON-NEXT: str d0, [sp, #8] // 8-byte Folded Spill
; CHECK-COMMON-NEXT: smstop sm
Expand Down Expand Up @@ -329,9 +331,9 @@ define fp128 @f128_call_sm(fp128 %a, fp128 %b) "aarch64_pstate_sm_enabled" nounw
; CHECK-COMMON-NEXT: stp d11, d10, [sp, #64] // 16-byte Folded Spill
; CHECK-COMMON-NEXT: stp d9, d8, [sp, #80] // 16-byte Folded Spill
; CHECK-COMMON-NEXT: str x30, [sp, #96] // 8-byte Folded Spill
; CHECK-COMMON-NEXT: stp q0, q1, [sp] // 32-byte Folded Spill
; CHECK-COMMON-NEXT: stp q1, q0, [sp] // 32-byte Folded Spill
; CHECK-COMMON-NEXT: smstop sm
; CHECK-COMMON-NEXT: ldp q0, q1, [sp] // 32-byte Folded Reload
; CHECK-COMMON-NEXT: ldp q1, q0, [sp] // 32-byte Folded Reload
; CHECK-COMMON-NEXT: bl __addtf3
; CHECK-COMMON-NEXT: str q0, [sp, #16] // 16-byte Folded Spill
; CHECK-COMMON-NEXT: smstart sm
Expand Down Expand Up @@ -390,9 +392,9 @@ define float @frem_call_sm(float %a, float %b) "aarch64_pstate_sm_enabled" nounw
; CHECK-COMMON-NEXT: stp d11, d10, [sp, #48] // 16-byte Folded Spill
; CHECK-COMMON-NEXT: stp d9, d8, [sp, #64] // 16-byte Folded Spill
; CHECK-COMMON-NEXT: str x30, [sp, #80] // 8-byte Folded Spill
; CHECK-COMMON-NEXT: stp s0, s1, [sp, #8] // 8-byte Folded Spill
; CHECK-COMMON-NEXT: stp s1, s0, [sp, #8] // 8-byte Folded Spill
; CHECK-COMMON-NEXT: smstop sm
; CHECK-COMMON-NEXT: ldp s0, s1, [sp, #8] // 8-byte Folded Reload
; CHECK-COMMON-NEXT: ldp s1, s0, [sp, #8] // 8-byte Folded Reload
; CHECK-COMMON-NEXT: bl fmodf
; CHECK-COMMON-NEXT: str s0, [sp, #12] // 4-byte Folded Spill
; CHECK-COMMON-NEXT: smstart sm
Expand Down Expand Up @@ -420,7 +422,9 @@ define float @frem_call_sm_compat(float %a, float %b) "aarch64_pstate_sm_compati
; CHECK-COMMON-NEXT: stp x30, x19, [sp, #80] // 16-byte Folded Spill
; CHECK-COMMON-NEXT: stp s0, s1, [sp, #8] // 8-byte Folded Spill
; CHECK-COMMON-NEXT: bl __arm_sme_state
; CHECK-COMMON-NEXT: ldp s2, s0, [sp, #8] // 8-byte Folded Reload
; CHECK-COMMON-NEXT: and x19, x0, #0x1
; CHECK-COMMON-NEXT: stp s2, s0, [sp, #8] // 8-byte Folded Spill
; CHECK-COMMON-NEXT: tbz w19, #0, .LBB12_2
; CHECK-COMMON-NEXT: // %bb.1:
; CHECK-COMMON-NEXT: smstop sm
Expand Down
Loading
Loading